feat(wan): Add Wan2.2 I2V support
265
docs/wan22-implementation-notes.md
Normal file
@@ -0,0 +1,265 @@
# Wan2.2 MLX Implementation Notes

> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / T2V-1.3B) to Apple MLX.

## Architecture Overview

Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts; they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).

### Key Parameters

| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride |
|-------|-----|-------|--------|----------|-----------|------------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) |

### Codebase Structure (~3900 lines of Wan2.2 code)

```
mlx_video/
├── generate_wan.py          # 483L - Generation pipeline (T2V + I2V)
├── convert_wan.py           # 564L - Weight conversion from HuggingFace
└── models/wan/
    ├── config.py            # 113L - Model configs (dataclass presets)
    ├── model.py             # 320L - DiT model (time embed, patchify, unpatchify)
    ├── transformer.py       # 91L - Attention block + FFN
    ├── attention.py         # 211L - Self-attention + cross-attention
    ├── rope.py              # 100L - 3D Rotary Position Embeddings
    ├── text_encoder.py      # 240L - T5 encoder (UMT5-XXL)
    ├── scheduler.py         # 428L - Euler, DPM++ 2M, UniPC schedulers
    ├── vae.py               # 315L - Wan2.1 VAE decoder (4×8×8)
    ├── vae22.py             # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
    ├── loading.py           # 154L - Model loading utilities
    └── i2v_utils.py         # 58L - I2V mask/preprocessing
```

---

## Critical Bugs & Fixes

### 1. MLX Underscore Attribute Gotcha

**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading.

**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.

**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX.

### 2. Patchify Channel Ordering

**Problem**: The patchify/unpatchify operations transposed channels incorrectly, producing a `[C fastest]` layout instead of `[C slowest]` and completely garbling the video output.

**Fix**: Changed the reshape to produce the correct `[B, T', H', W', pt*ph*pw*C]` ordering, matching PyTorch's contiguous memory layout.

**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout: PyTorch is row-major by default, so a reshape gives a different result whenever dimensions are reordered beforehand.

### 3. VAE AttentionBlock Reshape

**Problem**: The attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in the output.

**Fix**: Correct the reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention.

### 4. RMS Norm vs L2 Norm

**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.

**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`.

**Lesson**: Don't trust class names in reference code; read the actual computation.
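
The two formulas in the fix differ only by a constant factor of `sqrt(n)`, where `n` is the size of the normalized axis. The following is a minimal sketch in plain Python (not the MLX module itself; the helper names are hypothetical) that makes the gap concrete:

```python
import math


def l2_normalize(x):
    """Divide by the Euclidean (L2) norm: x / ||x||_2."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]


def rms_normalize(x):
    """Divide by the root mean square: x / sqrt(mean(x^2))."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]


x = [3.0, 4.0, 0.0, 0.0]
l2 = l2_normalize(x)    # result has unit L2 norm
rms = rms_normalize(x)  # result has unit root mean square
ratio = rms[0] / l2[0]  # equals sqrt(len(x))
```

For a channel axis with hundreds of entries the factor is well above 10, so swapping one formula for the other rescales every activation by that amount at each normalization layer.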

### 5. Video Codec Green Output

**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video.

**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.

---

## Precision & Dtype Flow

### The bfloat16 Autocast Pattern

The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually:

| Operation | Official (PyTorch) | MLX Implementation |
|---|---|---|
| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation |
| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` |
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE |
| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` |
| Residual stream | float32 | float32 (no cast) |

**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.

**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32: they are small values (~0.01–0.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations.

### T5 Encoder Precision

The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.

### VAE Decoder Precision

VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.

---

## Scheduler Implementation Details

### Three Schedulers: Euler, DPM++ 2M, UniPC

All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean).

**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging.

**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1).

**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical: it refines each step using the already-computed model output at **zero additional model evaluation cost**.
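
The boundary handling mentioned for DPM++ 2M can be sketched as follows. This assumes the common flow-matching convention `alpha = 1 - sigma` and `lambda = log(alpha / sigma)`; the function name `half_log_snr` is hypothetical and is not the `_lambda()` from `scheduler.py`:

```python
import math


def half_log_snr(sigma: float) -> float:
    """lambda = log(alpha / sigma) with alpha = 1 - sigma (assumed convention)."""
    if sigma <= 0.0:
        return math.inf    # clean sample: log(1 / 0) diverges upward
    if sigma >= 1.0:
        return -math.inf   # pure noise: log(0 / 1) diverges downward
    return math.log(1.0 - sigma) - math.log(sigma)


lam_clean = half_log_snr(0.0)
lam_noise = half_log_snr(1.0)
lam_mid = half_log_snr(0.5)  # alpha == sigma, so the log ratio is zero
```

Returning the infinities explicitly avoids calling `log(0)`, which would otherwise raise or produce NaNs at the first and last steps.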

### UniPC Corrector: Must Be Enabled

**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there is no flag; the corrector runs whenever `step_index > 0`).

**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.

### UniPC Corrector Coefficients

The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~6–13% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients.

### Sigma Schedule

```python
import numpy as np

num_steps = 20  # example value
shift = 5.0     # example value (see default shifts below)

# Flow-matching sigma schedule with shift
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
```

Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
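
To see what the shift does, it helps to evaluate the formula at a few points. A small worked example in plain Python (the helper name `shift_sigma` is made up for illustration):

```python
def shift_sigma(sigma: float, shift: float) -> float:
    """Apply the shift formula from the schedule above to a single sigma."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


mid_shift_3 = shift_sigma(0.5, 3.0)  # 1.5 / 2.0
mid_shift_5 = shift_sigma(0.5, 5.0)  # 2.5 / 3.0
top = shift_sigma(1.0, 5.0)          # endpoints are fixed points of the map
bottom = shift_sigma(0.0, 5.0)
```

The endpoints 0 and 1 stay in place while interior values move toward 1, so a larger shift spends more of the step budget at high noise levels.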

---

## Image-to-Video (I2V) Pipeline

### Per-Token Timesteps

I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:

```python
# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
t_tokens = mask_tokens * current_timestep  # first-frame → t=0
```

The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
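
A minimal sketch of how such a token mask could be built, using plain Python lists. It assumes tokens are laid out frame by frame (time is the slowest axis) and a temporal patch size of 1; `first_frame_token_mask` is a hypothetical helper, not the `build_i2v_mask` from `i2v_utils.py`:

```python
def first_frame_token_mask(t_lat, h_lat, w_lat, patch=(1, 2, 2)):
    """Flat 0/1 mask, one entry per token, assuming frame-major token order."""
    tokens_per_frame = (h_lat // patch[1]) * (w_lat // patch[2])
    num_frames = t_lat // patch[0]
    # First latent frame is the clean reference (0); the rest is denoised (1).
    return [0.0] * tokens_per_frame + [1.0] * (tokens_per_frame * (num_frames - 1))


mask_tokens = first_frame_token_mask(t_lat=3, h_lat=4, w_lat=4)
current_timestep = 900.0  # illustrative value
t_tokens = [m * current_timestep for m in mask_tokens]
```

With a 3×4×4 latent and 2×2 spatial patches there are 4 tokens per frame, so the first 4 entries of `t_tokens` are 0 and the remaining 8 carry the current timestep.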

### Mask Re-application

After each scheduler step, the first-frame latent is re-injected to prevent drift:

```python
latents = (1.0 - mask) * z_img + mask * latents
```
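
The blend can be checked on a toy example. This sketch uses flat Python lists instead of MLX arrays, and the values are invented for illustration:

```python
def reapply_first_frame(latents, z_img, mask):
    """Elementwise (1 - mask) * z_img + mask * latents over flat lists."""
    return [(1.0 - m) * z + m * x for x, z, m in zip(latents, z_img, mask)]


mask = [0.0, 0.0, 1.0, 1.0]     # 0 = first-frame position, 1 = generated position
z_img = [0.5, -0.5, 0.5, -0.5]  # encoded reference latent
latents = [9.0, 9.0, 2.0, 3.0]  # scheduler output; first-frame entries have drifted
blended = reapply_first_frame(latents, z_img, mask)
```

Positions where the mask is 0 are overwritten with the reference latent, and positions where it is 1 keep the scheduler's output unchanged.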

### VAE Encoder Temporal Downsample Order

The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:

- Stage 0: Spatial-only downsampling
- Stages 1–2: Spatial + temporal downsampling

This was initially set to `(True, True, False)`, which applied temporal downsampling in stages 0–1 instead of 1–2 and sent activations down the wrong processing paths.

---

## Dimension Constraints

### Patchify Alignment

Video height and width must be divisible by `patch_size × vae_stride`:

- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels
- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels

Example (TI2V-5B): for 720p (1280×720), 720 % 32 ≠ 0, so the height auto-aligns down to **704**; the width 1280 is already a multiple of 32.
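
The alignment rule can be written as a small helper. This is a sketch mirroring the round-down logic in the `generate_video` hunk of this commit; the name `align_down` is hypothetical:

```python
def align_down(value: int, patch: int, stride: int) -> int:
    """Round value down to a multiple of patch * stride, but never below one step."""
    step = patch * stride
    return max((value // step) * step, step)


height_5b = align_down(720, patch=2, stride=16)   # TI2V-5B, alignment 32
width_5b = align_down(1280, patch=2, stride=16)   # already aligned
height_14b = align_down(720, patch=2, stride=8)   # T2V-14B, alignment 16
tiny = align_down(10, patch=2, stride=16)         # clamps up to one full step
```

Rounding down rather than up keeps the output inside the requested frame, at the cost of cropping up to `step - 1` pixels per axis.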

### Frame Count

Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
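
The matching latent frame count follows the `t_latent` formula used in `generate_video`. A minimal sketch with a validity check (the function name and the error message are illustrative, not taken from the codebase):

```python
def latent_frames(num_frames: int, temporal_stride: int = 4) -> int:
    """Map a pixel-space frame count of the form 4n + 1 to latent frames."""
    if num_frames < 1 or (num_frames - 1) % temporal_stride != 0:
        raise ValueError(f"num_frames must be {temporal_stride}n + 1, got {num_frames}")
    return (num_frames - 1) // temporal_stride + 1


default_latents = latent_frames(81)  # (81 - 1) // 4 + 1
shortest_clip = latent_frames(5)     # (5 - 1) // 4 + 1
```

The `+ 1` accounts for the first frame, which maps to its own latent frame, while every later group of 4 frames shares one.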

---

## Performance Optimizations

### Batched CFG

Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:

```python
preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
```

**Result**: ~40% speedup by amortizing attention overhead.
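
After the batched pass, the two predictions are combined with the usual classifier-free guidance formula, as in the diffusion loop of this commit (`noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)`). A toy version over plain lists, with invented values:

```python
def apply_cfg(cond, uncond, guide_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond), elementwise."""
    return [u + guide_scale * (c - u) for c, u in zip(cond, uncond)]


guided = apply_cfg(cond=[1.0, 2.0], uncond=[0.5, 1.0], guide_scale=5.0)
unguided = apply_cfg(cond=[1.0, 2.0], uncond=[0.5, 1.0], guide_scale=1.0)
```

A scale of 1.0 returns the conditional prediction unchanged, and larger scales extrapolate further away from the unconditional one.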

### Precomputed Text Embeddings & Cross-Attention KV Cache

Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing them as caches eliminates redundant computation.

### Memory Management in Diffusion Loop

```python
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
```

MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.

---

## Weight Conversion

### Key Mapping Patterns

The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms:

1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)`
2. **Linear weights**: no transposition needed; PyTorch `(out, in)` and MLX `(out, in)` use the same convention for `nn.Linear`
3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key
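
The Conv3d transform in item 1 is a single axis permutation. The sketch below applies it to a shape tuple only, so it runs without any array library; the constant and function names are hypothetical:

```python
# Source axes (out, in, D, H, W) picked in the order that yields (out, D, H, W, in).
CONV3D_TORCH_TO_MLX_AXES = (0, 2, 3, 4, 1)


def permute_shape(shape, axes=CONV3D_TORCH_TO_MLX_AXES):
    """Return the shape a weight would have after transposing with `axes`."""
    return tuple(shape[a] for a in axes)


torch_shape = (8, 16, 1, 2, 3)        # (out, in, D, H, W), toy sizes
mlx_shape = permute_shape(torch_shape)
```

On a real weight the same tuple would be passed to the array's transpose method, for example NumPy's `weight.transpose(0, 2, 3, 4, 1)`.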

### Dual-Model Splitting

The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
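
At generation time the pipeline computes `boundary = config.boundary * config.num_train_timesteps` and uses it to switch between the two denoisers. The selection rule below is a sketch under two assumptions that should be checked against the actual code: that timesteps at or above the boundary go to the high-noise model, and the illustrative values 0.875 and 1000:

```python
def pick_denoiser(timestep: float, boundary_ratio: float,
                  num_train_timesteps: int = 1000) -> str:
    """Choose a denoiser by comparing the timestep with the scaled boundary."""
    boundary = boundary_ratio * num_train_timesteps
    # Assumption: early (noisy) steps use the high-noise model.
    return "high_noise" if timestep >= boundary else "low_noise"


early = pick_denoiser(900.0, boundary_ratio=0.875)  # above 875 -> high-noise
late = pick_denoiser(500.0, boundary_ratio=0.875)   # below 875 -> low-noise
```

Only one model is needed per step, which is why keeping the two sets of weights in separate files is convenient for loading.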

---

## Testing Strategy

260 tests across 9 files, all running in ~4 seconds:

| File | Focus |
|------|-------|
| test_wan_config.py | Config presets, field validation |
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
| test_wan_transformer.py | FFN, attention block, float32 modulation |
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
| test_wan_t5.py | T5 encoder layers and full encoding |
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |

Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.

---

## Known Issues

### I2V Quality Degradation

Frames 2–13 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes:

- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
- MLX-specific attention precision behavior

Better prompts and 720p resolution (the model's native resolution) help reduce artifacts.

### Chinese Negative Prompt

The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.
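
The HTML and whitespace part of that cleaning can be illustrated with the standard library alone. This sketch deliberately leaves out the `ftfy.fix_text()` step so it has no third-party dependency, and `clean_text_sketch` is a hypothetical name rather than the `_clean_text` helper in `loading.py`:

```python
import html
import re


def clean_text_sketch(text: str) -> str:
    """Double HTML unescape followed by whitespace normalization (no ftfy step)."""
    text = html.unescape(html.unescape(text))
    return re.sub(r"\s+", " ", text).strip()


cleaned = clean_text_sketch("  a &amp;amp; b \n")
```

The second `html.unescape` matters for doubly escaped input: the first pass turns `&amp;amp;` into `&amp;`, and only the second yields `&`.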
@@ -338,6 +338,10 @@ def convert_wan_checkpoint(
|
|||||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||||
f"heads={src_num_heads}, type={src_model_type}")
|
f"heads={src_num_heads}, type={src_model_type}")
|
||||||
|
|
||||||
|
# Use preset for known TI2V 5B configuration
|
||||||
|
if src_model_type == "ti2v" and src_dim == 3072:
|
||||||
|
return WanModelConfig.wan22_ti2v_5b()
|
||||||
|
|
||||||
is_22 = model_version == "2.2"
|
is_22 = model_version == "2.2"
|
||||||
|
|
||||||
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
||||||
@@ -409,7 +413,8 @@ def convert_wan_checkpoint(
|
|||||||
weights = load_torch_weights(str(vae_path))
|
weights = load_torch_weights(str(vae_path))
|
||||||
if is_wan22_vae:
|
if is_wan22_vae:
|
||||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
weights = sanitize_wan22_vae_weights(weights)
|
include_encoder = config.model_type == "ti2v"
|
||||||
|
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||||
else:
|
else:
|
||||||
weights = sanitize_wan_vae_weights(weights)
|
weights = sanitize_wan_vae_weights(weights)
|
||||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||||
|
|||||||
@@ -9,17 +9,7 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
# ANSI color codes
|
from mlx_video.utils import Colors
|
||||||
class Colors:
|
|
||||||
CYAN = "\033[96m"
|
|
||||||
BLUE = "\033[94m"
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
YELLOW = "\033[93m"
|
|
||||||
RED = "\033[91m"
|
|
||||||
MAGENTA = "\033[95m"
|
|
||||||
BOLD = "\033[1m"
|
|
||||||
DIM = "\033[2m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
|||||||
@@ -13,156 +13,27 @@ import mlx.nn as nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||||
|
from mlx_video.models.wan.loading import (
|
||||||
|
_clean_text,
|
||||||
|
encode_text,
|
||||||
|
load_t5_encoder,
|
||||||
|
load_vae_decoder,
|
||||||
|
load_vae_encoder,
|
||||||
|
load_wan_model,
|
||||||
|
)
|
||||||
|
from mlx_video.postprocess import save_video
|
||||||
|
from mlx_video.utils import Colors
|
||||||
|
|
||||||
class Colors:
|
# Backward-compat alias (tests and external code may use the old name)
|
||||||
CYAN = "\033[96m"
|
_build_i2v_mask = build_i2v_mask
|
||||||
BLUE = "\033[94m"
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
YELLOW = "\033[93m"
|
|
||||||
RED = "\033[91m"
|
|
||||||
MAGENTA = "\033[95m"
|
|
||||||
BOLD = "\033[1m"
|
|
||||||
DIM = "\033[2m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
|
|
||||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
|
||||||
"""Load and initialize WanModel, with optional quantization support.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to model safetensors file
|
|
||||||
config: WanModelConfig
|
|
||||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
|
||||||
If provided, creates QuantizedLinear stubs before loading.
|
|
||||||
"""
|
|
||||||
from mlx_video.models.wan.model import WanModel
|
|
||||||
|
|
||||||
model = WanModel(config)
|
|
||||||
|
|
||||||
if quantization:
|
|
||||||
from mlx_video.convert_wan import _quantize_predicate
|
|
||||||
|
|
||||||
nn.quantize(
|
|
||||||
model,
|
|
||||||
group_size=quantization["group_size"],
|
|
||||||
bits=quantization["bits"],
|
|
||||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = mx.load(str(model_path))
|
|
||||||
model.load_weights(list(weights.items()), strict=False)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_t5_encoder(model_path: Path, config):
|
|
||||||
"""Load T5 text encoder.
|
|
||||||
|
|
||||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
|
||||||
only runs once per generation, so performance impact is negligible.
|
|
||||||
This matches the official which computes softmax in float32 explicitly.
|
|
||||||
"""
|
|
||||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
|
||||||
|
|
||||||
encoder = T5Encoder(
|
|
||||||
vocab_size=config.t5_vocab_size,
|
|
||||||
dim=config.t5_dim,
|
|
||||||
dim_attn=config.t5_dim_attn,
|
|
||||||
dim_ffn=config.t5_dim_ffn,
|
|
||||||
num_heads=config.t5_num_heads,
|
|
||||||
num_layers=config.t5_num_layers,
|
|
||||||
num_buckets=config.t5_num_buckets,
|
|
||||||
shared_pos=False,
|
|
||||||
)
|
|
||||||
weights = mx.load(str(model_path))
|
|
||||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
|
||||||
encoder.load_weights(list(weights.items()))
|
|
||||||
mx.eval(encoder.parameters())
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def load_vae_decoder(model_path: Path, config=None):
|
|
||||||
"""Load VAE decoder (skips encoder weights with strict=False).
|
|
||||||
|
|
||||||
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
|
|
||||||
For Wan2.1 (vae_z_dim=16), uses WanVAE.
|
|
||||||
"""
|
|
||||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
|
||||||
|
|
||||||
if is_wan22:
|
|
||||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
|
||||||
vae = Wan22VAEDecoder(z_dim=48)
|
|
||||||
else:
|
|
||||||
from mlx_video.models.wan.vae import WanVAE
|
|
||||||
vae = WanVAE(z_dim=16)
|
|
||||||
|
|
||||||
weights = mx.load(str(model_path))
|
|
||||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
|
||||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
|
||||||
vae.load_weights(list(weights.items()), strict=False)
|
|
||||||
mx.eval(vae.parameters())
|
|
||||||
return vae
|
|
||||||
|
|
||||||
|
|
||||||
def _clean_text(text: str) -> str:
|
|
||||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
|
||||||
|
|
||||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
|
||||||
double HTML unescape, and whitespace normalization. Critical for
|
|
||||||
correct tokenization of the Chinese negative prompt.
|
|
||||||
"""
|
|
||||||
import html
|
|
||||||
import re
|
|
||||||
|
|
||||||
try:
|
|
||||||
import ftfy
|
|
||||||
text = ftfy.fix_text(text)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
text = html.unescape(html.unescape(text))
|
|
||||||
text = re.sub(r"\s+", " ", text).strip()
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def encode_text(
|
|
||||||
encoder,
|
|
||||||
tokenizer,
|
|
||||||
prompt: str,
|
|
||||||
text_len: int = 512,
|
|
||||||
) -> mx.array:
|
|
||||||
"""Encode text prompt using T5 encoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoder: T5Encoder model
|
|
||||||
tokenizer: HuggingFace tokenizer
|
|
||||||
prompt: Text prompt
|
|
||||||
text_len: Maximum text length
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Text embeddings [L, dim]
|
|
||||||
"""
|
|
||||||
prompt = _clean_text(prompt)
|
|
||||||
tokens = tokenizer(
|
|
||||||
prompt,
|
|
||||||
max_length=text_len,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="np",
|
|
||||||
)
|
|
||||||
ids = mx.array(tokens["input_ids"])
|
|
||||||
mask = mx.array(tokens["attention_mask"])
|
|
||||||
|
|
||||||
embeddings = encoder(ids, mask=mask)
|
|
||||||
|
|
||||||
# Return only non-padding tokens
|
|
||||||
seq_len = int(mask.sum().item())
|
|
||||||
return embeddings[0, :seq_len]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
model_dir: str,
|
model_dir: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str | None = None,
|
negative_prompt: str | None = None,
|
||||||
|
image: str | None = None,
|
||||||
width: int = 1280,
|
width: int = 1280,
|
||||||
height: int = 720,
|
height: int = 720,
|
||||||
num_frames: int = 81,
|
num_frames: int = 81,
|
||||||
@@ -173,12 +44,13 @@ def generate_video(
|
|||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
scheduler: str = "unipc",
|
scheduler: str = "unipc",
|
||||||
):
|
):
|
||||||
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
|
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_dir: Path to converted MLX model directory
|
model_dir: Path to converted MLX model directory
|
||||||
prompt: Text prompt
|
prompt: Text prompt
|
||||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||||
|
image: Path to input image for I2V (None = T2V mode)
|
||||||
width: Video width
|
width: Video width
|
||||||
height: Video height
|
height: Video height
|
||||||
num_frames: Number of frames (must be 4n+1)
|
num_frames: Number of frames (must be 4n+1)
|
||||||
@@ -240,6 +112,7 @@ def generate_video(
|
|||||||
config = WanModelConfig.wan21_t2v_14b()
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
|
|
||||||
is_dual = config.dual_model
|
is_dual = config.dual_model
|
||||||
|
is_i2v = image is not None
|
||||||
|
|
||||||
# Validate config against actual weights (handles mismatched config.json)
|
# Validate config against actual weights (handles mismatched config.json)
|
||||||
if not is_dual:
|
if not is_dual:
|
||||||
@@ -288,6 +161,7 @@ def generate_video(
|
|||||||
|
|
||||||
version_str = f"Wan{config.model_version}"
|
version_str = f"Wan{config.model_version}"
|
||||||
mode_str = "dual-model" if is_dual else "single-model"
|
mode_str = "dual-model" if is_dual else "single-model"
|
||||||
|
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||||
# Resolve negative prompt: explicit user value > config default
|
# Resolve negative prompt: explicit user value > config default
|
||||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||||
@@ -297,9 +171,11 @@ def generate_video(
|
|||||||
else:
|
else:
|
||||||
neg_prompt_resolved = negative_prompt
|
neg_prompt_resolved = negative_prompt
|
||||||
print(f"{Colors.CYAN}{'='*60}")
|
print(f"{Colors.CYAN}{'='*60}")
|
||||||
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
|
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
|
||||||
print(f"{'='*60}{Colors.RESET}")
|
print(f"{'='*60}{Colors.RESET}")
|
||||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||||
|
if is_i2v:
|
||||||
|
print(f" Image: {image}")
|
||||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||||
print(f" Neg prompt: {neg_display}")
|
print(f" Neg prompt: {neg_display}")
|
||||||
@@ -314,8 +190,22 @@ def generate_video(
|
|||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
||||||
|
|
||||||
# Compute target latent shape
|
# Align dimensions to patch_size * vae_stride (required for patchify)
|
||||||
vae_stride = config.vae_stride
|
vae_stride = config.vae_stride
|
||||||
|
patch_size = config.patch_size
|
||||||
|
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
|
||||||
|
align_w = patch_size[2] * vae_stride[2]
|
||||||
|
if height % align_h != 0 or width % align_w != 0:
|
||||||
|
old_h, old_w = height, width
|
||||||
|
height = (height // align_h) * align_h
|
||||||
|
width = (width // align_w) * align_w
|
||||||
|
if height == 0:
|
||||||
|
height = align_h
|
||||||
|
if width == 0:
|
||||||
|
width = align_w
|
||||||
|
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||||
|
|
||||||
|
# Compute target latent shape
|
||||||
z_dim = config.vae_z_dim
|
z_dim = config.vae_z_dim
|
||||||
t_latent = (num_frames - 1) // vae_stride[0] + 1
|
t_latent = (num_frames - 1) // vae_stride[0] + 1
|
||||||
h_latent = height // vae_stride[1]
|
h_latent = height // vae_stride[1]
|
||||||
@@ -323,7 +213,6 @@ def generate_video(
|
|||||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||||
|
|
||||||
# Sequence length for transformer
|
# Sequence length for transformer
|
||||||
patch_size = config.patch_size
|
|
||||||
seq_len = math.ceil(
|
seq_len = math.ceil(
|
||||||
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
||||||
)
|
)
|
||||||
@@ -352,6 +241,31 @@ def generate_video(
|
|||||||
gc.collect(); mx.clear_cache()
|
gc.collect(); mx.clear_cache()
|
||||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
|
# I2V: encode image to latent space
|
||||||
|
z_img = None
|
||||||
|
i2v_mask = None
|
||||||
|
i2v_mask_tokens = None
|
||||||
|
if is_i2v:
|
||||||
|
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||||
|
t_img = time.time()
|
||||||
|
img_tensor = preprocess_image(image, width, height)
|
||||||
|
mx.eval(img_tensor)
|
||||||
|
|
||||||
|
vae_path = model_dir / "vae.safetensors"
|
||||||
|
vae_enc = load_vae_encoder(vae_path, config)
|
||||||
|
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||||
|
mx.eval(z_img)
|
||||||
|
|
||||||
|
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
|
||||||
|
z_img = z_img[0].transpose(3, 0, 1, 2)
|
||||||
|
|
||||||
|
# Build I2V mask
|
||||||
|
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||||
|
|
||||||
|
del vae_enc, img_tensor
|
||||||
|
gc.collect(); mx.clear_cache()
|
||||||
|
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
# Load transformer models
|
# Load transformer models
|
||||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||||
if quantization:
|
if quantization:
|
||||||
@@ -398,12 +312,18 @@ def generate_video(
|
|||||||
# Generate initial noise
|
# Generate initial noise
|
||||||
noise = mx.random.normal(target_shape)
|
noise = mx.random.normal(target_shape)
|
||||||
|
|
||||||
|
# I2V: blend first-frame latent into noise
|
||||||
|
if is_i2v:
|
||||||
|
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
|
||||||
|
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||||
|
else:
|
||||||
|
latents = noise
|
||||||
|
|
||||||
# Boundary for model switching (dual model only)
|
# Boundary for model switching (dual model only)
|
||||||
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
||||||
|
|
||||||
# Diffusion loop
|
# Diffusion loop
|
||||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||||
latents = noise
|
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
|
|
||||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||||
@@ -424,10 +344,24 @@ def generate_video(
|
|||||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||||
kv = cross_kv
|
kv = cross_kv
|
||||||
|
|
||||||
|
# Build per-token timesteps for I2V (first-frame patches get t=0)
|
||||||
|
if is_i2v:
|
||||||
|
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
|
||||||
|
# Pad to seq_len if needed
|
||||||
|
pad_len = seq_len - t_tokens.shape[1]
|
||||||
|
if pad_len > 0:
|
||||||
|
t_tokens = mx.concatenate(
|
||||||
|
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||||
|
)
|
||||||
|
# Batch for CFG: both cond and uncond get same timesteps
|
||||||
|
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
|
||||||
|
else:
|
||||||
|
t_batch = mx.array([timestep_val, timestep_val])
|
||||||
|
|
||||||
# CFG: batch cond + uncond into single B=2 forward pass
|
# CFG: batch cond + uncond into single B=2 forward pass
|
||||||
preds = model(
|
preds = model(
|
||||||
[latents, latents],
|
[latents, latents],
|
||||||
t=mx.array([timestep_val, timestep_val]),
|
t=t_batch,
|
||||||
context=context_cfg,
|
context=context_cfg,
|
||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
cross_kv_caches=kv,
|
cross_kv_caches=kv,
|
||||||
@@ -438,6 +372,10 @@ def generate_video(
|
|||||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||||
|
|
||||||
|
# I2V: re-apply mask to keep first frame frozen
|
||||||
|
if is_i2v:
|
||||||
|
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||||
|
|
||||||
# Release temporaries before eval to free memory for graph execution
|
# Release temporaries before eval to free memory for graph execution
|
||||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
@@ -488,43 +426,12 @@ def generate_video(
|
|||||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||||
|
|
||||||
|
|
||||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
|
||||||
"""Save video frames to MP4.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames: Video frames [T, H, W, 3] uint8
|
|
||||||
output_path: Output file path
|
|
||||||
fps: Frames per second
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import imageio
|
|
||||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
|
||||||
for frame in frames:
|
|
||||||
writer.append_data(frame)
|
|
||||||
writer.close()
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
import cv2
|
|
||||||
h, w = frames.shape[1], frames.shape[2]
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
|
||||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
|
||||||
for frame in frames:
|
|
||||||
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
|
||||||
writer.release()
|
|
||||||
except (ImportError, Exception):
|
|
||||||
# Last resort: save as individual PNGs
|
|
||||||
from PIL import Image
|
|
||||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
for i, frame in enumerate(frames):
|
|
||||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
|
||||||
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||||
|
parser.add_argument("--image", type=str, default=None,
|
||||||
|
help="Path to input image for I2V (omit for T2V mode)")
|
||||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||||
@@ -559,6 +466,7 @@ def main():
|
|||||||
model_dir=args.model_dir,
|
model_dir=args.model_dir,
|
||||||
prompt=args.prompt,
|
prompt=args.prompt,
|
||||||
negative_prompt=neg_prompt,
|
negative_prompt=neg_prompt,
|
||||||
|
image=args.image,
|
||||||
width=args.width,
|
width=args.width,
|
||||||
height=args.height,
|
height=args.height,
|
||||||
num_frames=args.num_frames,
|
num_frames=args.num_frames,
|
||||||
|
|||||||
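Review note: the per-token timestep and first-frame freezing logic in the `generate_video` hunks above can be sketched in isolation. This is a minimal sketch, not the PR's MLX code: NumPy stands in for `mlx.core`, and the sizes (6 tokens per frame, 3 latent frames, `seq_len` of 20, timestep 750) are made up for illustration.

```python
import numpy as np

# Hypothetical sizes: 3 latent frames with 6 tokens each, padded to seq_len = 20.
tokens_per_frame, num_frames, seq_len = 6, 3, 20
timestep_val = 750.0

# Token mask: 0 for first-frame tokens, 1 for the rest (same convention as i2v_mask_tokens).
mask_tokens = np.ones((1, tokens_per_frame * num_frames), dtype=np.float32)
mask_tokens[:, :tokens_per_frame] = 0.0

# First-frame tokens see t = 0 (already clean); all others see the current timestep.
t_tokens = mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
    pad = np.full((1, pad_len), timestep_val, dtype=np.float32)
    t_tokens = np.concatenate([t_tokens, pad], axis=1)

# CFG batches cond and uncond together, so both rows share the same timesteps.
t_batch = np.concatenate([t_tokens, t_tokens], axis=0)  # [2, seq_len]

# After a scheduler step, the first latent frame is reset to the encoded image.
rng = np.random.default_rng(0)
z_img = rng.standard_normal((4, 1, 2, 3)).astype(np.float32)    # [C, 1, H, W]
latents = rng.standard_normal((4, 3, 2, 3)).astype(np.float32)  # [C, T, H, W]
i2v_mask = np.ones_like(latents)
i2v_mask[:, 0] = 0.0
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
```

Because the blend is re-applied after every scheduler step, the first latent frame stays equal to the encoded image whatever the model predicts for it, while the `t = 0` tokens tell the model that those patches are already noise-free.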
@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
|
|||||||
b, s, _ = x.shape
|
b, s, _ = x.shape
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
|
|
||||||
q = self.q(x)
|
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
k = self.k(x)
|
w_dtype = self.q.weight.dtype
|
||||||
|
x_w = x.astype(w_dtype)
|
||||||
|
|
||||||
|
q = self.q(x_w)
|
||||||
|
k = self.k(x_w)
|
||||||
if self.norm_q is not None:
|
if self.norm_q is not None:
|
||||||
q = self.norm_q(q)
|
q = self.norm_q(q)
|
||||||
if self.norm_k is not None:
|
if self.norm_k is not None:
|
||||||
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
|
|||||||
|
|
||||||
q = q.reshape(b, s, n, d)
|
q = q.reshape(b, s, n, d)
|
||||||
k = k.reshape(b, s, n, d)
|
k = k.reshape(b, s, n, d)
|
||||||
v = self.v(x).reshape(b, s, n, d)
|
v = self.v(x_w).reshape(b, s, n, d)
|
||||||
|
|
||||||
# Apply RoPE
|
# RoPE in float32 for precision (official uses float64)
|
||||||
q = rope_apply(q, grid_sizes, freqs)
|
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
|
||||||
k = rope_apply(k, grid_sizes, freqs)
|
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
|
||||||
|
|
||||||
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
|
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||||
q = q.transpose(0, 2, 1, 3)
|
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||||
k = k.transpose(0, 2, 1, 3)
|
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||||
v = v.transpose(0, 2, 1, 3)
|
v = v.transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
# Build attention mask from seq_lens
|
# Build attention mask from seq_lens
|
||||||
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
b = context.shape[0]
|
b = context.shape[0]
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
k = self.k(context)
|
# Cast to weight dtype for efficient matmul
|
||||||
|
w_dtype = self.k.weight.dtype
|
||||||
|
ctx = context.astype(w_dtype)
|
||||||
|
k = self.k(ctx)
|
||||||
if self.norm_k is not None:
|
if self.norm_k is not None:
|
||||||
k = self.norm_k(k)
|
k = self.norm_k(k)
|
||||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
|
|||||||
b = x.shape[0]
|
b = x.shape[0]
|
||||||
n, d = self.num_heads, self.head_dim
|
n, d = self.num_heads, self.head_dim
|
||||||
|
|
||||||
q = self.q(x)
|
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
|
w_dtype = self.q.weight.dtype
|
||||||
|
q = self.q(x.astype(w_dtype))
|
||||||
if self.norm_q is not None:
|
if self.norm_q is not None:
|
||||||
q = self.norm_q(q)
|
q = self.norm_q(q)
|
||||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||||
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
|
|||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k, v = kv_cache
|
k, v = kv_cache
|
||||||
else:
|
else:
|
||||||
k = self.k(context)
|
ctx = context.astype(w_dtype)
|
||||||
|
k = self.k(ctx)
|
||||||
if self.norm_k is not None:
|
if self.norm_k is not None:
|
||||||
k = self.norm_k(k)
|
k = self.norm_k(k)
|
||||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
# Optional context masking
|
# Optional context masking
|
||||||
mask = None
|
mask = None
|
||||||
|
|||||||
@@ -90,3 +90,24 @@ class WanModelConfig(BaseModelConfig):
|
|||||||
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
||||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||||
|
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
|
||||||
|
return cls(
|
||||||
|
model_type="ti2v",
|
||||||
|
dim=3072,
|
||||||
|
ffn_dim=14336,
|
||||||
|
in_dim=48,
|
||||||
|
out_dim=48,
|
||||||
|
num_heads=24,
|
||||||
|
num_layers=30,
|
||||||
|
vae_z_dim=48,
|
||||||
|
vae_stride=(4, 16, 16),
|
||||||
|
dual_model=False,
|
||||||
|
boundary=0.0,
|
||||||
|
sample_shift=5.0,
|
||||||
|
sample_steps=50,
|
||||||
|
sample_guide_scale=5.0,
|
||||||
|
sample_fps=24,
|
||||||
|
)
|
||||||
|
|||||||
58
mlx_video/models/wan/i2v_utils.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Image-to-Video utility functions for Wan2.2."""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
|
||||||
|
"""Load, resize, center-crop, and normalize an image for I2V.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to input image
|
||||||
|
width: Target width
|
||||||
|
height: Target height
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
|
||||||
|
"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
img = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Resize so that the image covers the target size (LANCZOS)
|
||||||
|
scale = max(width / img.width, height / img.height)
|
||||||
|
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||||
|
|
||||||
|
# Center crop
|
||||||
|
x1 = (img.width - width) // 2
|
||||||
|
y1 = (img.height - height) // 2
|
||||||
|
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||||
|
|
||||||
|
# To tensor: [H, W, 3] float32 in [-1, 1]
|
||||||
|
arr = np.array(img, dtype=np.float32) / 255.0
|
||||||
|
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
|
||||||
|
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def build_i2v_mask(z_shape, patch_size):
|
||||||
|
"""Build temporal mask for I2V: first frame = 0, rest = 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
z_shape: Latent shape (C, T, H, W) in channels-first
|
||||||
|
patch_size: (pt, ph, pw) patch size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
|
||||||
|
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
|
||||||
|
"""
|
||||||
|
C, T, H, W = z_shape
|
||||||
|
mask = mx.ones(z_shape)
|
||||||
|
# Zero out the first temporal position
|
||||||
|
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
|
||||||
|
|
||||||
|
# Token-level mask for per-token timesteps: subsample to patch grid
|
||||||
|
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
|
||||||
|
pt, ph, pw = patch_size
|
||||||
|
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
|
||||||
|
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
|
||||||
|
return mask, mask_tokens
|
||||||
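Review note: a small sketch of what `build_i2v_mask` returns, to make the two mask shapes concrete. NumPy is used here in place of `mlx.core`, and the latent shape and patch size `(1, 2, 2)` are illustrative assumptions, not values taken from the model config.

```python
import numpy as np


def build_i2v_mask_np(z_shape, patch_size):
    """NumPy stand-in for build_i2v_mask: 0 marks the first latent frame, 1 the rest."""
    C, T, H, W = z_shape
    mask = np.ones(z_shape, dtype=np.float32)
    mask[:, 0] = 0.0  # first temporal position is the conditioning frame

    # One token per (pt, ph, pw) patch: subsample channel 0 onto the patch grid.
    pt, ph, pw = patch_size
    mask_tokens = mask[0, ::pt, ::ph, ::pw].reshape(1, -1)
    return mask, mask_tokens


# Hypothetical small latent: 48 channels, 3 latent frames, 4x6 spatial, patch (1, 2, 2).
mask, mask_tokens = build_i2v_mask_np((48, 3, 4, 6), (1, 2, 2))
tokens_per_frame = (4 // 2) * (6 // 2)  # 6 tokens per latent frame
```

With these sizes the token mask has 18 entries, and the first `tokens_per_frame` of them are zero. The latent-space mask drives the blend with `z_img`, and the token-space mask drives the per-token timesteps.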
154
mlx_video/models/wan/loading.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""Wan model loading utilities."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||||
|
"""Load and initialize WanModel, with optional quantization support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model safetensors file
|
||||||
|
config: WanModelConfig
|
||||||
|
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||||
|
If provided, creates QuantizedLinear stubs before loading.
|
||||||
|
"""
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
|
model = WanModel(config)
|
||||||
|
|
||||||
|
if quantization:
|
||||||
|
from mlx_video.convert_wan import _quantize_predicate
|
||||||
|
|
||||||
|
nn.quantize(
|
||||||
|
model,
|
||||||
|
group_size=quantization["group_size"],
|
||||||
|
bits=quantization["bits"],
|
||||||
|
class_predicate=_quantize_predicate,
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = mx.load(str(model_path))
|
||||||
|
model.load_weights(list(weights.items()), strict=False)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_t5_encoder(model_path: Path, config):
|
||||||
|
"""Load T5 text encoder.
|
||||||
|
|
||||||
|
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||||
|
only runs once per generation, so performance impact is negligible.
|
||||||
|
This matches the official implementation, which computes softmax in float32 explicitly.
|
||||||
|
"""
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=config.t5_vocab_size,
|
||||||
|
dim=config.t5_dim,
|
||||||
|
dim_attn=config.t5_dim_attn,
|
||||||
|
dim_ffn=config.t5_dim_ffn,
|
||||||
|
num_heads=config.t5_num_heads,
|
||||||
|
num_layers=config.t5_num_layers,
|
||||||
|
num_buckets=config.t5_num_buckets,
|
||||||
|
shared_pos=False,
|
||||||
|
)
|
||||||
|
weights = mx.load(str(model_path))
|
||||||
|
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||||
|
encoder.load_weights(list(weights.items()))
|
||||||
|
mx.eval(encoder.parameters())
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_decoder(model_path: Path, config=None):
|
||||||
|
"""Load VAE decoder (skips encoder weights with strict=False).
|
||||||
|
|
||||||
|
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
|
||||||
|
For Wan2.1 (vae_z_dim=16), uses WanVAE.
|
||||||
|
"""
|
||||||
|
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||||
|
|
||||||
|
if is_wan22:
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
vae = Wan22VAEDecoder(z_dim=48)
|
||||||
|
else:
|
||||||
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
vae = WanVAE(z_dim=16)
|
||||||
|
|
||||||
|
weights = mx.load(str(model_path))
|
||||||
|
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||||
|
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||||
|
vae.load_weights(list(weights.items()), strict=False)
|
||||||
|
mx.eval(vae.parameters())
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_encoder(model_path: Path, config=None):
|
||||||
|
"""Load VAE encoder for I2V image encoding.
|
||||||
|
|
||||||
|
Only supports Wan2.2 (vae_z_dim=48).
|
||||||
|
"""
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||||
|
|
||||||
|
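# config defaults to None, so fall back to the only supported z_dim (48)
encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim if config is not None else 48)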
|
||||||
|
weights = mx.load(str(model_path))
|
||||||
|
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||||
|
encoder.load_weights(list(weights.items()), strict=False)
|
||||||
|
mx.eval(encoder.parameters())
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_text(text: str) -> str:
|
||||||
|
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||||
|
|
||||||
|
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||||
|
double HTML unescape, and whitespace normalization. Critical for
|
||||||
|
correct tokenization of the Chinese negative prompt.
|
||||||
|
"""
|
||||||
|
import html
|
||||||
|
import re
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ftfy
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def encode_text(
|
||||||
|
encoder,
|
||||||
|
tokenizer,
|
||||||
|
prompt: str,
|
||||||
|
text_len: int = 512,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Encode text prompt using T5 encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder: T5Encoder model
|
||||||
|
tokenizer: HuggingFace tokenizer
|
||||||
|
prompt: Text prompt
|
||||||
|
text_len: Maximum text length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text embeddings [L, dim]
|
||||||
|
"""
|
||||||
|
prompt = _clean_text(prompt)
|
||||||
|
tokens = tokenizer(
|
||||||
|
prompt,
|
||||||
|
max_length=text_len,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
ids = mx.array(tokens["input_ids"])
|
||||||
|
mask = mx.array(tokens["attention_mask"])
|
||||||
|
|
||||||
|
embeddings = encoder(ids, mask=mask)
|
||||||
|
|
||||||
|
# Return only non-padding tokens
|
||||||
|
seq_len = int(mask.sum().item())
|
||||||
|
return embeddings[0, :seq_len]
|
||||||
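Review note: the double `html.unescape` in `_clean_text` matters for text that was escaped twice. The sketch below uses only the standard library and leaves out the optional `ftfy` step; the function name `clean_text_sketch` and the sample string are illustrative.

```python
import html
import re


def clean_text_sketch(text: str) -> str:
    """Stdlib-only version of the cleanup: double HTML unescape, then collapse whitespace."""
    text = html.unescape(html.unescape(text))
    return re.sub(r"\s+", " ", text).strip()


# A single unescape still leaves an entity behind for doubly escaped input.
once = html.unescape("Tom &amp;amp; Jerry")
twice = clean_text_sketch("  Tom &amp;amp;\n Jerry ")
```

Here `once` still contains `&amp;`, while `twice` is the fully cleaned `"Tom & Jerry"` with the newline and surrounding spaces collapsed.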
@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim: Embedding dimension (must be even).
|
dim: Embedding dimension (must be even).
|
||||||
position: 1D tensor of positions.
|
position: Tensor of positions — 1D [L] or 2D [B, L].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Embeddings of shape [len(position), dim].
|
Embeddings of shape [L, dim] or [B, L, dim].
|
||||||
"""
|
"""
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
half = dim // 2
|
half = dim // 2
|
||||||
pos = position.astype(mx.float32)
|
pos = position.astype(mx.float32)
|
||||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||||
sinusoid = pos[:, None] * inv_freq[None, :]
|
sinusoid = pos[..., None] * inv_freq # [..., half]
|
||||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
class Head(nn.Module):
|
class Head(nn.Module):
|
||||||
@@ -44,16 +44,17 @@ class Head(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: [B, L, dim]
|
x: [B, L, dim]
|
||||||
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
|
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
|
||||||
"""
|
"""
|
||||||
if e.ndim == 2:
|
if e.ndim == 2:
|
||||||
e = e[:, None, :] # [B, 1, dim]
|
e = e[:, None, :] # [B, 1, dim]
|
||||||
e_f32 = e.astype(mx.float32)
|
e_f32 = e.astype(mx.float32)
|
||||||
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
|
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
|
||||||
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
|
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
|
||||||
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
|
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||||
|
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||||
x_norm = self.norm(x).astype(mx.float32)
|
x_norm = self.norm(x).astype(mx.float32)
|
||||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
|
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
|
||||||
return self.head(x_mod.astype(x.dtype))
|
return self.head(x_mod.astype(x.dtype))
|
||||||
|
|
||||||
|
|
||||||
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
|
|||||||
axis=0,
|
axis=0,
|
||||||
) # [B, seq_len, dim]
|
) # [B, seq_len, dim]
|
||||||
|
|
||||||
# Time embedding: compute once per sample, then broadcast to all tokens
|
# Time embedding
|
||||||
if t.ndim == 0:
|
if t.ndim == 0:
|
||||||
t = t[None]
|
t = t[None]
|
||||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
|
||||||
|
|
||||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
if t.ndim == 1:
|
||||||
|
# Standard T2V: scalar timestep per batch element [B]
|
||||||
|
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||||
e = self.time_embedding_1(
|
e = self.time_embedding_1(
|
||||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||||
) # [B, dim]
|
) # [B, dim]
|
||||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
|
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||||
e = e.astype(model_dtype)
|
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
|
||||||
|
e = e.astype(mx.float32)
|
||||||
|
else:
|
||||||
|
# I2V: per-token timesteps [B, L]
|
||||||
|
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
|
||||||
|
e = self.time_embedding_1(
|
||||||
|
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||||
|
) # [B, L, dim]
|
||||||
|
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||||
|
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||||
|
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
|
||||||
|
e = e.astype(mx.float32)
|
||||||
|
|
||||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||||
if isinstance(context, mx.array):
|
if isinstance(context, mx.array):
|
||||||
|
|||||||
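Review note: the change from `pos[:, None]` / `axis=1` to `pos[..., None]` / `axis=-1` is what lets `sinusoidal_embedding_1d` accept per-token timesteps of shape `[B, L]`. A minimal NumPy sketch of the same arithmetic (NumPy stands in for `mlx.core`, and `dim=8` is chosen only to keep the example small):

```python
import numpy as np


def sinusoidal_embedding_np(dim: int, position: np.ndarray) -> np.ndarray:
    """NumPy stand-in: accepts positions of shape [L] or [B, L]."""
    assert dim % 2 == 0
    half = dim // 2
    pos = position.astype(np.float32)
    inv_freq = np.power(10000.0, -np.arange(half, dtype=np.float32) / half)
    sinusoid = pos[..., None] * inv_freq  # [..., half]
    # Concatenating on the last axis keeps any leading batch/token axes intact.
    return np.concatenate([np.cos(sinusoid), np.sin(sinusoid)], axis=-1)


# One timestep per sample (T2V) versus one timestep per token (I2V).
per_sample = sinusoidal_embedding_np(8, np.array([750.0, 750.0]))
per_token = sinusoidal_embedding_np(8, np.array([[0.0, 750.0, 750.0]] * 2))
```

The per-sample call yields `[B, dim]` and the per-token call yields `[B, L, dim]`. A token with `t = 0` gets cosine components of 1 and sine components of 0, and tokens at the shared timestep get the same embedding as the per-sample path.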
@@ -187,7 +187,7 @@ class FlowUniPCScheduler:
|
|||||||
solver_order: int = 2,
|
solver_order: int = 2,
|
||||||
lower_order_final: bool = True,
|
lower_order_final: bool = True,
|
||||||
disable_corrector: list | None = None,
|
disable_corrector: list | None = None,
|
||||||
use_corrector: bool = False,
|
use_corrector: bool = True,
|
||||||
):
|
):
|
||||||
self.num_train_timesteps = num_train_timesteps
|
self.num_train_timesteps = num_train_timesteps
|
||||||
self.solver_order = solver_order
|
self.solver_order = solver_order
|
||||||
|
|||||||
@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
|
|||||||
context_lens: list | None = None,
|
context_lens: list | None = None,
|
||||||
cross_kv_cache: tuple | None = None,
|
cross_kv_cache: tuple | None = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
|
# Modulation in float32 (matching official torch.amp.autocast float32)
|
||||||
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
|
e_f32 = e.astype(mx.float32)
|
||||||
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
|
mod = self.modulation.astype(mx.float32) + e_f32
|
||||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||||
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
|
|||||||
e4 = mod[:, :, 4, :] # scale for ffn
|
e4 = mod[:, :, 4, :] # scale for ffn
|
||||||
e5 = mod[:, :, 5, :] # gate for ffn
|
e5 = mod[:, :, 5, :] # gate for ffn
|
||||||
|
|
||||||
# Self-attention with modulation
|
# Self-attention with modulation (norm output in float32)
|
||||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
|
||||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||||
x = x + y * e2
|
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
|
||||||
|
|
||||||
# Cross-attention (no modulation, just norm)
|
# Cross-attention (no modulation, just norm)
|
||||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||||
|
|
||||||
# FFN with modulation
|
# FFN with modulation (norm output in float32)
|
||||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
|
||||||
y = self.ffn(x_mod)
|
y = self.ffn(x_mod)
|
||||||
x = x + y * e5
|
x = x + y.astype(mx.float32) * e5
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
|
|||||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
return self.fc2(self.act(self.fc1(x)))
|
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||||
|
x_w = x.astype(self.fc1.weight.dtype)
|
||||||
|
return self.fc2(self.act(self.fc1(x_w)))
|
||||||
|
|||||||
@@ -53,7 +53,9 @@ class CausalConv3d(nn.Module):
|
|||||||
|
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self._causal_pad_t = 2 * padding[0]
|
# Causal temporal padding: always kernel_size-1 on the left.
|
||||||
|
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
|
||||||
|
self._causal_pad_t = kernel_size[0] - 1
|
||||||
self._pad_h = padding[1]
|
self._pad_h = padding[1]
|
||||||
self._pad_w = padding[2]
|
self._pad_w = padding[2]
|
||||||
|
|
||||||
@@ -250,6 +252,46 @@ class DupUp3D(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AvgDown3D(nn.Module):
|
||||||
|
"""Downsample by grouping channels across spatial/temporal factors and averaging.
|
||||||
|
|
||||||
|
Inverse of DupUp3D. No learnable parameters.
|
||||||
|
Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.factor_t = factor_t
|
||||||
|
self.factor_s = factor_s
|
||||||
|
self.factor = factor_t * factor_s * factor_s
|
||||||
|
assert in_channels * self.factor % out_channels == 0
|
||||||
|
self.group_size = in_channels * self.factor // out_channels
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
# x: [B, T, H, W, C]
|
||||||
|
B, T, H, W, C = x.shape
|
||||||
|
|
||||||
|
# Pad temporal if not divisible by factor_t
|
||||||
|
pad_t = (self.factor_t - T % self.factor_t) % self.factor_t
|
||||||
|
if pad_t > 0:
|
||||||
|
x = mx.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
|
||||||
|
T = T + pad_t
|
||||||
|
|
||||||
|
ft, fs = self.factor_t, self.factor_s
|
||||||
|
# Reshape to split spatial/temporal dims
|
||||||
|
x = x.reshape(B, T // ft, ft, H // fs, fs, W // fs, fs, C)
|
||||||
|
# Move factors next to channels
|
||||||
|
x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, T', H', W', C, ft, fs, fs]
|
||||||
|
# Expand channels
|
||||||
|
x = x.reshape(B, T // ft, H // fs, W // fs, C * self.factor)
|
||||||
|
# Group and average
|
||||||
|
x = x.reshape(B, T // ft, H // fs, W // fs, self.out_channels, self.group_size)
|
||||||
|
x = x.mean(axis=-1)
|
||||||
|
return x
|
||||||
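Review note: the reshape and transpose sequence in `AvgDown3D.__call__` is easier to follow on a tiny input. The sketch below mirrors it in NumPy (an assumption made for portability; the PR uses `mlx.core`) and skips the intermediate `C * factor` reshape, which is equivalent on contiguous data. The input sizes are arbitrary.

```python
import numpy as np


def avg_down_3d_np(x, out_channels, factor_t, factor_s):
    """NumPy stand-in for AvgDown3D.__call__ on channels-last input [B, T, H, W, C]."""
    B, T, H, W, C = x.shape
    factor = factor_t * factor_s * factor_s
    group_size = C * factor // out_channels

    # Left-pad time with zeros so T divides by factor_t.
    pad_t = (factor_t - T % factor_t) % factor_t
    if pad_t > 0:
        x = np.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
        T += pad_t

    ft, fs = factor_t, factor_s
    x = x.reshape(B, T // ft, ft, H // fs, fs, W // fs, fs, C)
    x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6)  # [B, T', H', W', C, ft, fs, fs]
    # Fold channel and factor axes together, then average within each group.
    x = x.reshape(B, T // ft, H // fs, W // fs, out_channels, group_size)
    return x.mean(axis=-1)


# 3 frames of ones: time is left-padded to 4, then halved in T, H and W.
x = np.ones((1, 3, 4, 4, 8), dtype=np.float32)
y = avg_down_3d_np(x, out_channels=16, factor_t=2, factor_s=2)
```

With 8 input channels and a combined factor of 8, each of the 16 output channels averages a group of 4 values, and the result has shape `(1, 2, 2, 2, 16)`. The zero-padded frame only affects the first temporal output position.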
|
|
||||||
|
|
||||||
class Resample(nn.Module):
|
class Resample(nn.Module):
|
||||||
"""Spatial up/downsampling with optional temporal up/downsampling."""
|
"""Spatial up/downsampling with optional temporal up/downsampling."""
|
||||||
|
|
||||||
@@ -267,6 +309,15 @@ class Resample(nn.Module):
|
|||||||
self.resample_bias = mx.zeros((dim,))
|
self.resample_bias = mx.zeros((dim,))
|
||||||
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
|
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
|
||||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||||
|
elif mode == "downsample2d":
|
||||||
|
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||||
|
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||||
|
self.resample_bias = mx.zeros((dim,))
|
||||||
|
elif mode == "downsample3d":
|
||||||
|
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||||
|
self.resample_bias = mx.zeros((dim,))
|
||||||
|
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||||
|
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported mode: {mode}")
|
raise ValueError(f"Unsupported mode: {mode}")
|
||||||
|
|
||||||
@@ -283,6 +334,12 @@ class Resample(nn.Module):
|
|||||||
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||||
return mx.conv_general(x, self.resample_weight) + self.resample_bias
|
return mx.conv_general(x, self.resample_weight) + self.resample_bias
|
||||||
|
|
||||||
|
def _downsample_conv2d(self, x):
|
||||||
|
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
|
||||||
|
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
|
||||||
|
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||||
|
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
|
||||||
|
|
||||||
def __call__(self, x, first_chunk=False):
|
def __call__(self, x, first_chunk=False):
|
||||||
# x: [B, T, H, W, C]
|
# x: [B, T, H, W, C]
|
||||||
B, T, H, W, C = x.shape
|
B, T, H, W, C = x.shape
|
||||||
@@ -320,6 +377,15 @@ class Resample(nn.Module):
|
|||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
T = x.shape[1]
|
T = x.shape[1]
|
||||||
|
|
||||||
|
if self.mode == "downsample3d" and T > 1:
|
||||||
|
# Temporal downsample via strided CausalConv3d
|
||||||
|
# Skip for T=1 (single frame) — matches official chunked encoding
|
||||||
|
# where first chunk stores cache but doesn't apply time_conv
|
||||||
|
x = self.time_conv(x)
|
||||||
|
mx.eval(x)
|
||||||
|
T = x.shape[1]
|
||||||
|
|
||||||
|
if self.mode in ("upsample2d", "upsample3d"):
|
||||||
# Spatial upsample in temporal chunks to limit peak memory
|
# Spatial upsample in temporal chunks to limit peak memory
|
||||||
chunk_size = 8
|
chunk_size = 8
|
||||||
chunks = []
|
chunks = []
|
||||||
@@ -334,6 +400,14 @@ class Resample(nn.Module):
|
|||||||
x = mx.concatenate(chunks, axis=0)
|
x = mx.concatenate(chunks, axis=0)
|
||||||
H2, W2 = x.shape[1], x.shape[2]
|
H2, W2 = x.shape[1], x.shape[2]
|
||||||
x = x.reshape(B, T, H2, W2, C)
|
x = x.reshape(B, T, H2, W2, C)
|
||||||
|
elif self.mode in ("downsample2d", "downsample3d"):
|
||||||
|
# Spatial downsample: per-frame strided Conv2d
|
||||||
|
x_flat = x.reshape(B * T, H, W, C)
|
||||||
|
x_flat = self._downsample_conv2d(x_flat)
|
||||||
|
mx.eval(x_flat)
|
||||||
|
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
||||||
|
x = x_flat.reshape(B, T, H2, W2, C)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -383,6 +457,44 @@ class Up_ResidualBlock(nn.Module):
|
|||||||
return x_main
|
return x_main
|
||||||
|
|
||||||
|
|
||||||
|
class Down_ResidualBlock(nn.Module):
|
||||||
|
"""Downsampling residual block with AvgDown3D shortcut."""
|
||||||
|
|
||||||
|
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
|
||||||
|
super().__init__()
|
||||||
|
self.down_flag = down_flag
|
||||||
|
|
||||||
|
# AvgDown3D shortcut (no learnable params, always present)
|
||||||
|
self.avg_shortcut = AvgDown3D(
|
||||||
|
in_dim, out_dim,
|
||||||
|
factor_t=2 if temperal_downsample else 1,
|
||||||
|
factor_s=2 if down_flag else 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main path: ResidualBlocks + optional Resample
|
||||||
|
blocks = []
|
||||||
|
dim_in = in_dim
|
||||||
|
for _ in range(num_res_blocks):
|
||||||
|
blocks.append(ResidualBlock(dim_in, out_dim))
|
||||||
|
dim_in = out_dim
|
||||||
|
|
||||||
|
if down_flag:
|
||||||
|
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
||||||
|
blocks.append(Resample(out_dim, mode=mode))
|
||||||
|
|
||||||
|
self.downsamples = blocks
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x_shortcut = self.avg_shortcut(x)
|
||||||
|
mx.eval(x_shortcut)
|
||||||
|
|
||||||
|
for module in self.downsamples:
|
||||||
|
x = module(x)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
return x + x_shortcut
|
||||||
|
|
||||||
|
|
||||||
class Decoder3d(nn.Module):
|
class Decoder3d(nn.Module):
|
||||||
"""Wan2.2 3D VAE Decoder."""
|
"""Wan2.2 3D VAE Decoder."""
|
||||||
|
|
||||||
@@ -439,6 +551,63 @@ class Decoder3d(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder3d(nn.Module):
|
||||||
|
"""Wan2.2 3D VAE Encoder. Mirror of Decoder3d with downsampling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim=160,
|
||||||
|
z_dim=96,
|
||||||
|
dim_mult=(1, 2, 4, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
temperal_downsample=(False, True, True),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Channel dimensions: [160, 160, 320, 640, 640]
|
||||||
|
dims = [dim * m for m in [1] + list(dim_mult)]
|
||||||
|
|
||||||
|
# Initial conv: patchified input (12 ch) → first dim
|
||||||
|
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
||||||
|
|
||||||
|
# Downsample blocks
|
||||||
|
self.downsamples = []
|
||||||
|
for i in range(len(dim_mult)):
|
||||||
|
in_d, out_d = dims[i], dims[i + 1]
|
||||||
|
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
|
||||||
|
self.downsamples.append(Down_ResidualBlock(
|
||||||
|
in_dim=in_d,
|
||||||
|
out_dim=out_d,
|
||||||
|
num_res_blocks=num_res_blocks,
|
||||||
|
temperal_downsample=t_down,
|
||||||
|
down_flag=(i < len(dim_mult) - 1),
|
||||||
|
))
|
||||||
|
|
||||||
|
# Middle blocks (same as decoder)
|
||||||
|
out_dim = dims[-1]
|
||||||
|
self.middle = [
|
||||||
|
ResidualBlock(out_dim, out_dim),
|
||||||
|
AttentionBlock(out_dim),
|
||||||
|
ResidualBlock(out_dim, out_dim),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
|
||||||
|
self.head = Head22(out_dim, out_channels=z_dim)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
# x: [B, T, H, W, 12] (patchified)
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
for layer in self.downsamples:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
for layer in self.middle:
|
||||||
|
x = layer(x)
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
x = self.head(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Head22(nn.Module):
|
class Head22(nn.Module):
|
||||||
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
|
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
|
||||||
|
|
||||||
@@ -460,6 +629,46 @@ class Head22(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Wan22VAEEncoder(nn.Module):
|
||||||
|
"""Full Wan2.2 VAE encoder with patchify and normalization."""
|
||||||
|
|
||||||
|
def __init__(self, z_dim=48, dim=160):
|
||||||
|
super().__init__()
|
||||||
|
self.z_dim = z_dim
|
||||||
|
# conv1: top-level 1x1x1 conv after encoder (z_dim*2 → z_dim*2)
|
||||||
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
|
self.encoder = Encoder3d(
|
||||||
|
dim=dim,
|
||||||
|
z_dim=z_dim * 2, # Encoder outputs z_dim*2, split into mu + log_var
|
||||||
|
dim_mult=(1, 2, 4, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
temperal_downsample=(False, True, True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
"""Encode image/video to latent space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: [B, T, H, W, 3] image/video in [-1, 1]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
|
||||||
|
"""
|
||||||
|
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
|
||||||
|
x = _patchify(img, patch_size=2)
|
||||||
|
|
||||||
|
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
|
||||||
|
out = self.encoder(x)
|
||||||
|
|
||||||
|
# conv1 (pointwise) + split into mu, log_var
|
||||||
|
out = self.conv1(out)
|
||||||
|
mu = out[:, :, :, :, :self.z_dim]
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
mu = normalize_latents(mu)
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
class Wan22VAEDecoder(nn.Module):
|
class Wan22VAEDecoder(nn.Module):
|
||||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||||
|
|
||||||
@@ -507,6 +716,15 @@ def denormalize_latents(z, mean=None, std=None):
|
|||||||
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_latents(z, mean=None, std=None):
|
||||||
|
"""Normalize latents: z_norm = (z - mean) / std. Inverse of denormalize_latents."""
|
||||||
|
if mean is None:
|
||||||
|
mean = VAE22_MEAN
|
||||||
|
if std is None:
|
||||||
|
std = VAE22_STD
|
||||||
|
return (z - mean.reshape(1, 1, 1, 1, -1)) / std.reshape(1, 1, 1, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
def _unpatchify(x, patch_size=2):
|
def _unpatchify(x, patch_size=2):
|
||||||
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C]
|
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C]
|
||||||
For p=2: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
|
For p=2: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
|
||||||
@@ -527,10 +745,30 @@ def _unpatchify(x, patch_size=2):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
def _patchify(x, patch_size=2):
|
||||||
|
"""Convert spatial to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]
|
||||||
|
Inverse of _unpatchify.
|
||||||
|
PyTorch: b c f (h q) (w r) -> b (c r q) f h w
|
||||||
|
In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
|
||||||
|
"""
|
||||||
|
if patch_size == 1:
|
||||||
|
return x
|
||||||
|
B, T, Hfull, Wfull, C = x.shape
|
||||||
|
H = Hfull // patch_size
|
||||||
|
W = Wfull // patch_size
|
||||||
|
# [B, T, H, q, W, r, C]
|
||||||
|
x = x.reshape(B, T, H, patch_size, W, patch_size, C)
|
||||||
|
# Rearrange to pack q,r into channels: [B, T, H, W, C, r, q]
|
||||||
|
x = x.transpose(0, 1, 2, 4, 6, 5, 3) # [B, T, H, W, C, r, q]
|
||||||
|
x = x.reshape(B, T, H, W, C * patch_size * patch_size)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
|
||||||
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
|
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
|
||||||
|
|
||||||
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
|
By default keeps decoder + conv2 weights only. Set include_encoder=True
|
||||||
|
to also keep encoder + conv1 weights (needed for I2V encoding).
|
||||||
Transposes conv weights from channels-first to channels-last.
|
Transposes conv weights from channels-first to channels-last.
|
||||||
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
|
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
|
||||||
Maps PyTorch nn.Sequential indices to our named layers.
|
Maps PyTorch nn.Sequential indices to our named layers.
|
||||||
@@ -538,7 +776,8 @@ def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
|||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
# Skip encoder and conv1 (encoder-only)
|
# Skip encoder and conv1 unless requested
|
||||||
|
if not include_encoder:
|
||||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
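Review note on `_patchify` in `vae22.py`: the reshape/transpose order can be sanity-checked outside MLX. The following is a minimal NumPy sketch, assuming the same channels-last layout as the diff. `patchify_np` and `unpatchify_np` are hypothetical names used only for illustration. The inverse here is derived from the transpose order in `_patchify`, not copied from the `_unpatchify` body, which this diff does not show.

```python
import numpy as np


def patchify_np(x, patch_size=2):
    """[B, T, H*p, W*p, C] -> [B, T, H, W, C*p*p], packing channels as (c, r, q)."""
    if patch_size == 1:
        return x
    B, T, h_full, w_full, C = x.shape
    H, W = h_full // patch_size, w_full // patch_size
    x = x.reshape(B, T, H, patch_size, W, patch_size, C)  # [B, T, H, q, W, r, C]
    x = x.transpose(0, 1, 2, 4, 6, 5, 3)                  # [B, T, H, W, C, r, q]
    return x.reshape(B, T, H, W, C * patch_size * patch_size)


def unpatchify_np(x, patch_size=2):
    """Assumed inverse of patchify_np: [B, T, H, W, C*p*p] -> [B, T, H*p, W*p, C]."""
    if patch_size == 1:
        return x
    B, T, H, W, packed = x.shape
    C = packed // (patch_size * patch_size)
    x = x.reshape(B, T, H, W, C, patch_size, patch_size)  # [B, T, H, W, C, r, q]
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)                  # [B, T, H, q, W, r, C]
    return x.reshape(B, T, H * patch_size, W * patch_size, C)


# Round trip on a small tensor with distinct values.
video = np.arange(1 * 2 * 4 * 6 * 3).reshape(1, 2, 4, 6, 3)
packed = patchify_np(video)
restored = unpatchify_np(packed)
```

With `patch_size=2`, packed channel index `c*4 + r*2 + q` at spatial position `(h, w)` holds the input value at `(2*h + q, 2*w + r, c)`, which is the `b (c r q) f h w` ordering quoted in the docstring.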
@@ -1,8 +1,42 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||||
|
"""Save video frames to MP4.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames: Video frames [T, H, W, 3] uint8
|
||||||
|
output_path: Output file path
|
||||||
|
fps: Frames per second
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import imageio
|
||||||
|
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||||
|
for frame in frames:
|
||||||
|
writer.append_data(frame)
|
||||||
|
writer.close()
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
h, w = frames.shape[1], frames.shape[2]
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||||
|
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||||
|
for frame in frames:
|
||||||
|
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
|
writer.release()
|
||||||
|
except Exception:
|
||||||
|
# Last resort: save as individual PNGs
|
||||||
|
from PIL import Image
|
||||||
|
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for i, frame in enumerate(frames):
|
||||||
|
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||||
|
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
||||||
|
|
||||||
|
|
||||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,20 @@ from pathlib import Path
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class Colors:
|
||||||
|
"""ANSI color codes for terminal output."""
|
||||||
|
|
||||||
|
CYAN = "\033[96m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RED = "\033[91m"
|
||||||
|
MAGENTA = "\033[95m"
|
||||||
|
BOLD = "\033[1m"
|
||||||
|
DIM = "\033[2m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
def get_model_path(model_repo: str):
|
def get_model_path(model_repo: str):
|
||||||
"""Get or download LTX-2 model path."""
|
"""Get or download LTX-2 model path."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
4
tests/conftest.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
2712
tests/test_wan.py
File diff suppressed because it is too large
372
tests/test_wan_attention.py
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
"""Tests for Wan attention components and RoPE."""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RoPE Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRoPE:
|
||||||
|
"""Tests for 3-way factorized RoPE."""
|
||||||
|
|
||||||
|
def test_rope_params_shape(self):
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
freqs = rope_params(1024, 64)
|
||||||
|
mx.eval(freqs)
|
||||||
|
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||||
|
|
||||||
|
def test_rope_params_different_dims(self):
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
for dim in [32, 64, 128]:
|
||||||
|
freqs = rope_params(512, dim)
|
||||||
|
mx.eval(freqs)
|
||||||
|
assert freqs.shape == (512, dim // 2, 2)
|
||||||
|
|
||||||
|
def test_rope_params_cos_sin_range(self):
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
freqs = rope_params(256, 64)
|
||||||
|
mx.eval(freqs)
|
||||||
|
cos_vals = np.array(freqs[:, :, 0])
|
||||||
|
sin_vals = np.array(freqs[:, :, 1])
|
||||||
|
assert np.all(cos_vals >= -1.0) and np.all(cos_vals <= 1.0)
|
||||||
|
assert np.all(sin_vals >= -1.0) and np.all(sin_vals <= 1.0)
|
||||||
|
|
||||||
|
def test_rope_params_position_zero(self):
|
||||||
|
"""At position 0, cos should be 1 and sin should be 0."""
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
freqs = rope_params(10, 64)
|
||||||
|
mx.eval(freqs)
|
||||||
|
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||||
|
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||||
|
|
||||||
|
def test_rope_apply_output_shape(self):
|
||||||
|
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||||
|
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||||
|
x = mx.random.normal((B, L, N, D))
|
||||||
|
freqs = rope_params(1024, D)
|
||||||
|
grid_sizes = [(2, 3, 4)] # F*H*W = 24 = L
|
||||||
|
out = rope_apply(x, grid_sizes, freqs)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L, N, D)
|
||||||
|
|
||||||
|
def test_rope_apply_preserves_norm(self):
|
||||||
|
"""RoPE rotation should preserve vector norms."""
|
||||||
|
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||||
|
B, N, D = 1, 2, 16
|
||||||
|
F, H, W = 2, 3, 4
|
||||||
|
L = F * H * W
|
||||||
|
x = mx.random.normal((B, L, N, D))
|
||||||
|
freqs = rope_params(1024, D)
|
||||||
|
|
||||||
|
out = rope_apply(x, [(F, H, W)], freqs)
|
||||||
|
mx.eval(x, out)
|
||||||
|
|
||||||
|
x_np = np.array(x[0])
|
||||||
|
out_np = np.array(out[0])
|
||||||
|
for i in range(L):
|
||||||
|
for h in range(N):
|
||||||
|
norm_in = np.linalg.norm(x_np[i, h])
|
||||||
|
norm_out = np.linalg.norm(out_np[i, h])
|
||||||
|
np.testing.assert_allclose(norm_in, norm_out, rtol=1e-4)
|
||||||
|
|
||||||
|
def test_rope_apply_with_padding(self):
|
||||||
|
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||||
|
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||||
|
B, N, D = 1, 2, 16
|
||||||
|
F, H, W = 2, 2, 2
|
||||||
|
seq_len = F * H * W # 8
|
||||||
|
pad = 4
|
||||||
|
L = seq_len + pad
|
||||||
|
x = mx.random.normal((B, L, N, D))
|
||||||
|
freqs = rope_params(1024, D)
|
||||||
|
|
||||||
|
out = rope_apply(x, [(F, H, W)], freqs)
|
||||||
|
mx.eval(x, out)
|
||||||
|
# Padded tokens should be unchanged
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(out[0, seq_len:]),
|
||||||
|
np.array(x[0, seq_len:]),
|
||||||
|
atol=1e-6,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rope_apply_batch(self):
|
||||||
|
"""Test with batch_size > 1 and one grid size per batch element."""
|
||||||
|
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||||
|
B, N, D = 2, 2, 16
|
||||||
|
grids = [(2, 3, 4), (2, 3, 4)]
|
||||||
|
L = 2 * 3 * 4
|
||||||
|
x = mx.random.normal((B, L, N, D))
|
||||||
|
freqs = rope_params(1024, D)
|
||||||
|
|
||||||
|
out = rope_apply(x, grids, freqs)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L, N, D)
|
||||||
|
|
||||||
|
def test_rope_frequency_split(self):
|
||||||
|
"""Verify the 3-way frequency dimension split matches Wan2.2 convention."""
|
||||||
|
D = 128 # head_dim for 14B model
|
||||||
|
half_d = D // 2
|
||||||
|
d_t = half_d - 2 * (half_d // 3)
|
||||||
|
d_h = half_d // 3
|
||||||
|
d_w = half_d // 3
|
||||||
|
assert d_t + d_h + d_w == half_d
|
||||||
|
# Temporal gets more capacity
|
||||||
|
assert d_t >= d_h
|
||||||
|
assert d_t >= d_w
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Attention Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWanRMSNorm:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
norm = WanRMSNorm(64)
|
||||||
|
x = mx.random.normal((2, 10, 64))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (2, 10, 64)
|
||||||
|
|
||||||
|
def test_zero_mean_variance(self):
|
||||||
|
"""With the default weight of 1, RMS norm output should have RMS ≈ 1."""
|
||||||
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
norm = WanRMSNorm(64)
|
||||||
|
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
out_np = np.array(out[0])
|
||||||
|
for i in range(5):
|
||||||
|
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||||
|
# After RMS norm with weight=1, RMS should be ~1
|
||||||
|
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
|
||||||
|
|
||||||
|
def test_dtype_preservation(self):
|
||||||
|
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||||
|
from mlx_video.models.wan.attention import WanRMSNorm
|
||||||
|
norm = WanRMSNorm(32)
|
||||||
|
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# Weight is float32, so multiplication promotes result to float32
|
||||||
|
assert out.dtype == mx.float32
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanLayerNorm:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
norm = WanLayerNorm(64)
|
||||||
|
x = mx.random.normal((2, 10, 64))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (2, 10, 64)
|
||||||
|
|
||||||
|
def test_without_affine(self):
|
||||||
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||||
|
x = mx.random.normal((1, 4, 64))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# Mean should be ~0, variance should be ~1
|
||||||
|
out_np = np.array(out[0])
|
||||||
|
for i in range(4):
|
||||||
|
np.testing.assert_allclose(np.mean(out_np[i]), 0.0, atol=0.05)
|
||||||
|
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
|
||||||
|
|
||||||
|
def test_with_affine(self):
|
||||||
|
from mlx_video.models.wan.attention import WanLayerNorm
|
||||||
|
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||||
|
assert hasattr(norm, "weight")
|
||||||
|
assert hasattr(norm, "bias")
|
||||||
|
x = mx.random.normal((1, 4, 32))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 4, 32)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanSelfAttention:
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
self.dim = 64
|
||||||
|
self.num_heads = 4
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
|
B, L = 1, 24
|
||||||
|
F, H, W = 2, 3, 4
|
||||||
|
x = mx.random.normal((B, L, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
out = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L, self.dim)
|
||||||
|
|
||||||
|
def test_with_qk_norm(self):
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||||
|
assert attn.norm_q is not None
|
||||||
|
assert attn.norm_k is not None
|
||||||
|
|
||||||
|
def test_without_qk_norm(self):
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||||
|
assert attn.norm_q is None
|
||||||
|
assert attn.norm_k is None
|
||||||
|
|
||||||
|
def test_masking(self):
|
||||||
|
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||||
|
B, L = 1, 24
|
||||||
|
F, H, W = 2, 3, 4
|
||||||
|
x = mx.random.normal((B, L, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
|
||||||
|
# Full sequence
|
||||||
|
out_full = attn(x, seq_lens=[L], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||||
|
# Shorter sequence (mask last 4 tokens)
|
||||||
|
out_masked = attn(x, seq_lens=[L - 4], grid_sizes=[(F, H, W)], freqs=freqs)
|
||||||
|
mx.eval(out_full, out_masked)
|
||||||
|
|
||||||
|
# Outputs should differ when masking is applied
|
||||||
|
assert not np.allclose(np.array(out_full), np.array(out_masked), atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanCrossAttention:
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
self.dim = 64
|
||||||
|
self.num_heads = 4
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
|
B, L_q, L_kv = 1, 24, 16
|
||||||
|
x = mx.random.normal((B, L_q, self.dim))
|
||||||
|
context = mx.random.normal((B, L_kv, self.dim))
|
||||||
|
out = attn(x, context)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L_q, self.dim)
|
||||||
|
|
||||||
|
def test_with_context_mask(self):
|
||||||
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
|
B, L_q, L_kv = 1, 12, 16
|
||||||
|
x = mx.random.normal((B, L_q, self.dim))
|
||||||
|
context = mx.random.normal((B, L_kv, self.dim))
|
||||||
|
out = attn(x, context, context_lens=[10])
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L_q, self.dim)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# bfloat16 Autocast Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBFloat16Autocast:
|
||||||
|
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||||
|
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
self.dim = 64
|
||||||
|
self.num_heads = 4
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_bf16(params):
|
||||||
|
"""Recursively cast all arrays in params to bfloat16."""
|
||||||
|
if isinstance(params, dict):
|
||||||
|
return {k: TestBFloat16Autocast._to_bf16(v) for k, v in params.items()}
|
||||||
|
elif isinstance(params, list):
|
||||||
|
return [TestBFloat16Autocast._to_bf16(v) for v in params]
|
||||||
|
elif isinstance(params, mx.array):
|
||||||
|
return params.astype(mx.bfloat16)
|
||||||
|
return params
|
||||||
|
|
||||||
|
def test_self_attn_casts_to_weight_dtype(self):
|
||||||
|
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, self.dim)
|
||||||
|
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||||
|
|
||||||
|
def test_cross_attn_casts_to_weight_dtype(self):
|
||||||
|
"""Cross-attention should cast input to weight dtype."""
|
||||||
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
|
ctx = mx.random.normal((1, 4, self.dim))
|
||||||
|
out = attn(x, ctx)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, self.dim)
|
||||||
|
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||||
|
|
||||||
|
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||||
|
"""prepare_kv should cast context to weight dtype."""
|
||||||
|
from mlx_video.models.wan.attention import WanCrossAttention
|
||||||
|
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||||
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
|
ctx = mx.random.normal((1, 4, self.dim))
|
||||||
|
k, v = attn.prepare_kv(ctx)
|
||||||
|
mx.eval(k, v)
|
||||||
|
assert k.dtype == mx.bfloat16
|
||||||
|
assert v.dtype == mx.bfloat16
|
||||||
|
|
||||||
|
def test_ffn_casts_to_weight_dtype(self):
|
||||||
|
"""FFN should cast input to weight dtype for linear layers."""
|
||||||
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
ffn = WanFFN(self.dim, 128)
|
||||||
|
ffn.update(self._to_bf16(ffn.parameters()))
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
|
out = ffn(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, self.dim)
|
||||||
|
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||||
|
|
||||||
|
def test_self_attn_rope_in_float32(self):
|
||||||
|
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||||
|
from mlx_video.models.wan.attention import WanSelfAttention
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||||
|
attn.update(self._to_bf16(attn.parameters()))
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
assert freqs.dtype == mx.float32
|
||||||
|
out = attn(x, seq_lens=[8], grid_sizes=[(2, 2, 2)], freqs=freqs)
|
||||||
|
mx.eval(out)
|
||||||
|
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||||
|
|
||||||
|
def test_block_float32_residual_with_bf16_weights(self):
|
||||||
|
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||||
|
block.update(self._to_bf16(block.parameters()))
|
||||||
|
|
||||||
|
B, L = 1, 8
|
||||||
|
x = mx.random.normal((B, L, self.dim))
|
||||||
|
e = mx.random.normal((B, L, 6, self.dim))
|
||||||
|
ctx = mx.random.normal((B, 4, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
|
||||||
|
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.dtype == mx.float32
|
||||||
|
assert np.isfinite(np.array(out)).all()
|
||||||
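Review note on `test_rope_frequency_split`: the test only covers `head_dim = 128`. The following is a small pure-Python sketch, assuming the same split formula as the test. `rope_split` is a hypothetical helper, not part of `mlx_video.models.wan.rope`. It shows how the frequency pairs are divided among the temporal, height, and width axes for other head dimensions used in the tests.

```python
def rope_split(head_dim):
    """Split head_dim // 2 frequency pairs into (temporal, height, width) parts.

    Mirrors the arithmetic in test_rope_frequency_split: height and width each
    get floor(half / 3) pairs, and the temporal axis takes the remainder.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even for rotary embeddings")
    half = head_dim // 2
    d_h = half // 3
    d_w = half // 3
    d_t = half - d_h - d_w
    return d_t, d_h, d_w


# head_dim values that appear in the tests above: 128 (14B), 32, and 16.
splits = {d: rope_split(d) for d in (128, 32, 16)}
```

For `head_dim = 128` this gives `(22, 21, 21)`, so the temporal axis receives the one leftover pair; for small head dimensions the remainder is proportionally larger.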
125
tests/test_wan_config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Tests for Wan model configuration."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWanModelConfig:
|
||||||
|
"""Tests for WanModelConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
assert config.dim == 5120
|
||||||
|
assert config.ffn_dim == 13824
|
||||||
|
assert config.num_heads == 40
|
||||||
|
assert config.num_layers == 40
|
||||||
|
assert config.in_dim == 16
|
||||||
|
assert config.out_dim == 16
|
||||||
|
assert config.patch_size == (1, 2, 2)
|
||||||
|
assert config.vae_stride == (4, 8, 8)
|
||||||
|
assert config.vae_z_dim == 16
|
||||||
|
assert config.boundary == 0.875
|
||||||
|
assert config.sample_shift == 12.0
|
||||||
|
assert config.sample_steps == 40
|
||||||
|
assert config.sample_guide_scale == (3.0, 4.0)
|
||||||
|
assert config.num_train_timesteps == 1000
|
||||||
|
assert config.qk_norm is True
|
||||||
|
assert config.cross_attn_norm is True
|
||||||
|
assert config.text_len == 512
|
||||||
|
|
||||||
|
def test_head_dim_property(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
assert config.head_dim == 128 # 5120 // 40
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
d = config.to_dict()
|
||||||
|
assert isinstance(d, dict)
|
||||||
|
assert d["dim"] == 5120
|
||||||
|
assert d["patch_size"] == (1, 2, 2)
|
||||||
|
assert d["boundary"] == 0.875
|
||||||
|
|
||||||
|
def test_t5_config_values(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
assert config.t5_vocab_size == 256384
|
||||||
|
assert config.t5_dim == 4096
|
||||||
|
assert config.t5_dim_attn == 4096
|
||||||
|
assert config.t5_dim_ffn == 10240
|
||||||
|
assert config.t5_num_heads == 64
|
||||||
|
assert config.t5_num_layers == 24
|
||||||
|
assert config.t5_num_buckets == 32
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Wan2.1 Config Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWan21Config:
|
||||||
|
"""Tests for Wan2.1 and Wan2.2 config presets."""
|
||||||
|
|
||||||
|
def test_wan21_14b_factory(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
|
assert config.model_version == "2.1"
|
||||||
|
assert config.dual_model is False
|
||||||
|
assert config.dim == 5120
|
||||||
|
assert config.ffn_dim == 13824
|
||||||
|
assert config.num_heads == 40
|
||||||
|
assert config.num_layers == 40
|
||||||
|
assert config.head_dim == 128
|
||||||
|
assert config.sample_guide_scale == 5.0
|
||||||
|
assert config.sample_shift == 5.0
|
||||||
|
assert config.sample_steps == 50
|
||||||
|
assert config.boundary == 0.0
|
||||||
|
|
||||||
|
def test_wan21_1_3b_factory(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
|
assert config.model_version == "2.1"
|
||||||
|
assert config.dual_model is False
|
||||||
|
assert config.dim == 1536
|
||||||
|
assert config.ffn_dim == 8960
|
||||||
|
assert config.num_heads == 12
|
||||||
|
assert config.num_layers == 30
|
||||||
|
assert config.head_dim == 128 # 1536 // 12
|
||||||
|
assert config.sample_guide_scale == 5.0
|
||||||
|
|
||||||
|
def test_wan22_14b_factory(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan22_t2v_14b()
|
||||||
|
assert config.model_version == "2.2"
|
||||||
|
assert config.dual_model is True
|
||||||
|
assert config.dim == 5120
|
||||||
|
assert config.sample_guide_scale == (3.0, 4.0)
|
||||||
|
assert config.sample_shift == 12.0
|
||||||
|
assert config.sample_steps == 40
|
||||||
|
assert config.boundary == 0.875
|
||||||
|
|
||||||
|
def test_wan21_config_to_dict(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
|
d = config.to_dict()
|
||||||
|
assert d["model_version"] == "2.1"
|
||||||
|
assert d["dual_model"] is False
|
||||||
|
assert d["sample_guide_scale"] == 5.0
|
||||||
|
|
||||||
|
def test_wan21_1_3b_config_to_dict(self):
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
|
d = config.to_dict()
|
||||||
|
assert d["dim"] == 1536
|
||||||
|
assert d["num_layers"] == 30
|
||||||
|
|
||||||
|
def test_default_config_is_wan22(self):
|
||||||
|
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
assert config.model_version == "2.2"
|
||||||
|
assert config.dual_model is True
|
||||||
235
tests/test_wan_convert.py
Normal file
@@ -0,0 +1,235 @@
"""Tests for Wan weight conversion utilities."""

import mlx.core as mx
import numpy as np
import pytest


# ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests
# ---------------------------------------------------------------------------

class TestSanitizeTransformerWeights:
    def test_patch_embedding_reshape(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
            "patch_embedding.bias": mx.random.normal((5120,)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "patch_embedding_proj.weight" in out
        assert "patch_embedding_proj.bias" in out
        assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)

    def test_text_embedding_rename(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "text_embedding.0.weight": mx.zeros((64, 32)),
            "text_embedding.0.bias": mx.zeros((64,)),
            "text_embedding.2.weight": mx.zeros((64, 64)),
            "text_embedding.2.bias": mx.zeros((64,)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "text_embedding_0.weight" in out
        assert "text_embedding_0.bias" in out
        assert "text_embedding_1.weight" in out
        assert "text_embedding_1.bias" in out

    def test_time_embedding_rename(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "time_embedding.0.weight": mx.zeros((64, 32)),
            "time_embedding.2.weight": mx.zeros((64, 64)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "time_embedding_0.weight" in out
        assert "time_embedding_1.weight" in out

    def test_time_projection_rename(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "time_projection.1.weight": mx.zeros((384, 64)),
            "time_projection.1.bias": mx.zeros((384,)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "time_projection.weight" in out
        assert "time_projection.bias" in out

    def test_ffn_rename(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "blocks.0.ffn.0.weight": mx.zeros((128, 64)),
            "blocks.0.ffn.0.bias": mx.zeros((128,)),
            "blocks.0.ffn.2.weight": mx.zeros((64, 128)),
            "blocks.0.ffn.2.bias": mx.zeros((64,)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "blocks.0.ffn.fc1.weight" in out
        assert "blocks.0.ffn.fc1.bias" in out
        assert "blocks.0.ffn.fc2.weight" in out
        assert "blocks.0.ffn.fc2.bias" in out

    def test_freqs_skipped(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "freqs": mx.zeros((1024, 64, 2)),
            "blocks.0.norm1.weight": mx.zeros((64,)),
        }
        out = sanitize_wan_transformer_weights(weights)
        assert "freqs" not in out
        assert "blocks.0.norm1.weight" in out

    def test_passthrough_keys(self):
        from mlx_video.convert_wan import sanitize_wan_transformer_weights
        weights = {
            "blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
            "blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
            "blocks.0.self_attn.v.weight": mx.zeros((64, 64)),
            "blocks.0.self_attn.o.weight": mx.zeros((64, 64)),
            "blocks.0.modulation": mx.zeros((1, 6, 64)),
            "head.head.weight": mx.zeros((64, 64)),
            "head.modulation": mx.zeros((1, 2, 64)),
        }
        out = sanitize_wan_transformer_weights(weights)
        for key in weights:
            assert key in out


class TestSanitizeT5Weights:
    def test_gate_rename(self):
        from mlx_video.convert_wan import sanitize_wan_t5_weights
        weights = {
            "blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
            "blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
            "blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
        }
        out = sanitize_wan_t5_weights(weights)
        assert "blocks.0.ffn.gate_proj.weight" in out
        assert "blocks.0.ffn.fc1.weight" in out
        assert "blocks.0.ffn.fc2.weight" in out

    def test_passthrough(self):
        from mlx_video.convert_wan import sanitize_wan_t5_weights
        weights = {
            "token_embedding.weight": mx.zeros((100, 64)),
            "blocks.0.attn.q.weight": mx.zeros((64, 64)),
            "norm.weight": mx.zeros((64,)),
        }
        out = sanitize_wan_t5_weights(weights)
        for key in weights:
            assert key in out


class TestSanitizeVAEWeights:
    def test_conv3d_transpose(self):
        from mlx_video.convert_wan import sanitize_wan_vae_weights
        weights = {
            "decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),  # [O, I, D, H, W]
        }
        out = sanitize_wan_vae_weights(weights)
        assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4)  # [O, D, H, W, I]

    def test_conv2d_transpose(self):
        from mlx_video.convert_wan import sanitize_wan_vae_weights
        weights = {
            "decoder.proj.weight": mx.zeros((16, 8, 3, 3)),  # [O, I, H, W]
        }
        out = sanitize_wan_vae_weights(weights)
        assert out["decoder.proj.weight"].shape == (16, 3, 3, 8)  # [O, H, W, I]

    def test_non_conv_passthrough(self):
        from mlx_video.convert_wan import sanitize_wan_vae_weights
        weights = {
            "decoder.norm.weight": mx.zeros((64,)),  # 1D, no transpose
            "decoder.bias": mx.zeros((16,)),
        }
        out = sanitize_wan_vae_weights(weights)
        assert out["decoder.norm.weight"].shape == (64,)
        assert out["decoder.bias"].shape == (16,)

    def test_mixed_weights(self):
        from mlx_video.convert_wan import sanitize_wan_vae_weights
        weights = {
            "conv3d.weight": mx.zeros((8, 4, 3, 3, 3)),  # 5D
            "conv2d.weight": mx.zeros((8, 4, 3, 3)),  # 4D
            "linear.weight": mx.zeros((8, 4)),  # 2D
            "norm.weight": mx.zeros((8,)),  # 1D
        }
        out = sanitize_wan_vae_weights(weights)
        assert out["conv3d.weight"].shape == (8, 3, 3, 3, 4)
        assert out["conv2d.weight"].shape == (8, 3, 3, 4)
        assert out["linear.weight"].shape == (8, 4)
        assert out["norm.weight"].shape == (8,)


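The VAE tests above expect convolution weights to move from a channels-first layout `[O, I, *spatial]` to a channels-last layout `[O, *spatial, I]`, while 1D and 2D tensors pass through unchanged. A minimal sketch of that axis permutation, using only shapes, is given below. The helper names `channels_last_axes` and `permute_shape` are hypothetical and are not taken from `convert_wan.py`; the actual conversion code may be organised differently.

```python
def channels_last_axes(ndim):
    """Axis order that maps [O, I, *spatial] to [O, *spatial, I].

    Hypothetical helper: only 4D (conv2d) and 5D (conv3d) weights are permuted,
    every other rank keeps its original axis order.
    """
    if ndim in (4, 5):
        return (0, *range(2, ndim), 1)
    return tuple(range(ndim))


def permute_shape(shape, axes):
    """Shape that a transpose with the given axis order would produce."""
    return tuple(shape[a] for a in axes)


conv3d_shape = permute_shape((8, 4, 3, 3, 3), channels_last_axes(5))
conv2d_shape = permute_shape((16, 8, 3, 3), channels_last_axes(4))
linear_shape = permute_shape((8, 4), channels_last_axes(2))
```

An axis tuple of this form is what one would typically pass to an array transpose call when converting the real tensors.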
# ---------------------------------------------------------------------------
# Wan2.1 Conversion Tests
# ---------------------------------------------------------------------------

class TestWan21Convert:
    """Tests for Wan2.1 conversion support."""

    def test_auto_detect_wan21(self, tmp_path):
        """Auto-detect single-model directory as Wan2.1."""
        # Create a Wan2.1-style directory (no low_noise_model subdir)
        (tmp_path / "dummy.safetensors").touch()
        # The auto-detect logic: no low_noise_model dir → 2.1
        from pathlib import Path
        low = tmp_path / "low_noise_model"
        assert not low.exists()
        # Simulates auto detection
        version = "2.2" if low.exists() else "2.1"
        assert version == "2.1"

    def test_auto_detect_wan22(self, tmp_path):
        """Auto-detect dual-model directory as Wan2.2."""
        (tmp_path / "low_noise_model").mkdir()
        (tmp_path / "high_noise_model").mkdir()
        from pathlib import Path
        low = tmp_path / "low_noise_model"
        assert low.exists()
        version = "2.2" if low.exists() else "2.1"
        assert version == "2.2"

    def test_wan21_config_saved_correctly(self):
        """Verify config dict has correct fields for Wan2.1."""
        from mlx_video.models.wan.config import WanModelConfig
        config = WanModelConfig.wan21_t2v_14b()
        d = config.to_dict()
        assert d["model_version"] == "2.1"
        assert d["dual_model"] is False
        assert d["sample_steps"] == 50
        assert d["sample_shift"] == 5.0


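The two auto-detect tests above inline the rule they check: a checkpoint directory containing a `low_noise_model` subdirectory is treated as Wan2.2, anything else as Wan2.1. The sketch below wraps that rule in a function so it can be exercised outside pytest. `detect_wan_version` is a hypothetical name and is not confirmed to exist in `convert_wan.py`; it also uses `is_dir()` where the tests use `exists()`, which is a slightly stricter assumption.

```python
import tempfile
from pathlib import Path


def detect_wan_version(model_dir: Path) -> str:
    """Return "2.2" for a dual-model layout, "2.1" otherwise (hypothetical helper)."""
    return "2.2" if (model_dir / "low_noise_model").is_dir() else "2.1"


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)

    # Wan2.1-style layout: weights sit directly in the directory.
    single = root / "wan21"
    single.mkdir()
    (single / "dummy.safetensors").touch()

    # Wan2.2-style layout: separate low-noise and high-noise model directories.
    dual = root / "wan22"
    (dual / "low_noise_model").mkdir(parents=True)
    (dual / "high_noise_model").mkdir()

    version_single = detect_wan_version(single)
    version_dual = detect_wan_version(dual)
```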
# ---------------------------------------------------------------------------
# Encoder Weight Sanitization Tests
# ---------------------------------------------------------------------------

class TestSanitizeEncoderWeights:
    """Tests for sanitize_wan22_vae_weights with include_encoder."""

    def test_exclude_encoder_by_default(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        weights = {
            "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
            "conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
            "conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
        }
        out = sanitize_wan22_vae_weights(weights, include_encoder=False)
        assert "conv2.weight" in out
        assert not any("encoder" in k or k.startswith("conv1") for k in out)

    def test_include_encoder(self):
        from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights

        weights = {
            "encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
            "conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
            "conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
        }
        out = sanitize_wan22_vae_weights(weights, include_encoder=True)
        assert "encoder.conv1.weight" in out
        assert "conv1.weight" in out
        assert "conv2.weight" in out
238
tests/test_wan_generate.py
Normal file
@@ -0,0 +1,238 @@
"""Tests for end-to-end generation and I2V mask construction."""

import mlx.core as mx
import numpy as np
import pytest

from wan_test_helpers import _make_tiny_config


# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------

class TestEndToEnd:
    """End-to-end test with tiny model (no real weights needed)."""

    def test_tiny_model_denoise_step(self):
        """Simulate one denoising step with tiny model."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(42)
        config = _make_tiny_config()
        model = WanModel(config)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        sched = FlowMatchEulerScheduler()
        sched.set_timesteps(5, shift=3.0)

        latents = mx.random.normal((C, F, H, W))
        context = mx.random.normal((4, config.text_dim))

        # One step
        t = sched.timesteps[0]
        pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
        latents_next = sched.step(pred[None], t, latents[None]).squeeze(0)
        mx.eval(latents_next)

        assert latents_next.shape == (C, F, H, W)
        # Should differ from original noise
        assert not np.allclose(np.array(latents_next), np.array(latents), atol=1e-5)

    def test_tiny_model_full_loop(self):
        """Run a complete (tiny) diffusion loop."""
        from mlx_video.models.wan.model import WanModel
        from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler

        mx.random.seed(123)
        config = _make_tiny_config()
        model = WanModel(config)

        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        sched = FlowMatchEulerScheduler()
        num_steps = 3
        sched.set_timesteps(num_steps, shift=3.0)

        latents = mx.random.normal((C, F, H, W))
        context = mx.random.normal((4, config.text_dim))

        for i in range(num_steps):
            t = sched.timesteps[i]
            pred = model([latents], mx.array([t.item()]), [context], seq_len)[0]
            latents = sched.step(pred[None], t, latents[None]).squeeze(0)
            mx.eval(latents)

        assert latents.shape == (C, F, H, W)
        assert not mx.any(mx.isnan(latents)).item(), "NaN in output"
        assert not mx.any(mx.isinf(latents)).item(), "Inf in output"


# ---------------------------------------------------------------------------
# I2V Mask Tests
# ---------------------------------------------------------------------------

class TestI2VMask:
    """Tests for _build_i2v_mask."""

    def test_mask_shapes(self):
        from mlx_video.generate_wan import _build_i2v_mask

        z_shape = (48, 5, 4, 4)  # C, T, H, W
        patch_size = (1, 2, 2)
        mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
        assert mask.shape == z_shape
        # Tokens: T=5, H/2=2, W/2=2 → 5*2*2 = 20
        assert mask_tokens.shape == (1, 20)

    def test_first_frame_zero(self):
        from mlx_video.generate_wan import _build_i2v_mask

        z_shape = (48, 5, 4, 4)
        mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
        mx.eval(mask, mask_tokens)
        # First temporal position should be 0
        assert float(mask[:, 0, :, :].max()) == 0.0
        # Rest should be 1
        assert float(mask[:, 1:, :, :].min()) == 1.0
        # First-frame tokens (T=0) should be 0 in mask_tokens
        # With T=5, H'=2, W'=2: first 4 tokens are frame 0
        assert float(mask_tokens[0, :4].max()) == 0.0
        assert float(mask_tokens[0, 4:].min()) == 1.0


class TestI2VMaskAlignment:
    """Tests that I2V mask works correctly with various aligned dimensions."""

    def test_mask_with_ti2v_dimensions(self):
        """Mask should work with TI2V-5B typical dimensions."""
        from mlx_video.generate_wan import _build_i2v_mask
        # TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
        # 704x1280 → latent 44x80, t_latent=21 for 81 frames
        z_shape = (48, 21, 44, 80)
        patch_size = (1, 2, 2)
        mask, mask_tokens = _build_i2v_mask(z_shape, patch_size)
        mx.eval(mask, mask_tokens)

        assert mask.shape == z_shape
        assert float(mask[:, 0].max()) == 0.0
        assert float(mask[:, 1:].min()) == 1.0

        expected_tokens = 21 * 22 * 40  # T * (H/ph) * (W/pw)
        assert mask_tokens.shape == (1, expected_tokens)
        first_frame_tokens = 1 * 22 * 40  # pt=1
        assert float(mask_tokens[0, :first_frame_tokens].max()) == 0.0
        assert float(mask_tokens[0, first_frame_tokens:].min()) == 1.0

    def test_mask_per_token_timestep(self):
        """Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
        from mlx_video.generate_wan import _build_i2v_mask
        z_shape = (4, 3, 4, 4)
        patch_size = (1, 2, 2)
        _, mask_tokens = _build_i2v_mask(z_shape, patch_size)
        mx.eval(mask_tokens)

        timestep_val = 0.8
        t_tokens = mask_tokens * timestep_val
        mx.eval(t_tokens)

        first_tokens = 1 * 2 * 2  # pt * (H/ph) * (W/pw)
        np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
        np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)


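The I2V mask tests above pin down the token-level behaviour: tokens are ordered frame by frame, tokens of the first latent frame carry mask value 0, all others carry 1, and multiplying the mask by the current timestep gives per-token timesteps. The following pure-Python sketch reproduces only that token-level view. `build_token_mask` and `per_token_timesteps` are hypothetical names, the sketch assumes a temporal patch size of 1, and it is not the `_build_i2v_mask` implementation from `generate_wan.py`.

```python
def build_token_mask(z_shape, patch_size):
    """Flat per-token mask: 0.0 for first-frame tokens, 1.0 for the rest.

    Assumes frame-major token order and pt == 1 (hypothetical helper).
    """
    _, t, h, w = z_shape
    pt, ph, pw = patch_size
    if pt != 1:
        raise ValueError("this sketch only covers a temporal patch size of 1")
    tokens_per_frame = (h // ph) * (w // pw)
    return [0.0] * tokens_per_frame + [1.0] * (tokens_per_frame * (t - 1))


def per_token_timesteps(mask_tokens, timestep):
    """First-frame tokens are pinned to t = 0; all others get the current timestep."""
    return [m * timestep for m in mask_tokens]


# TI2V-5B style latent: 21 latent frames of 44 x 80, patch size (1, 2, 2).
mask_tokens = build_token_mask((48, 21, 44, 80), (1, 2, 2))
t_tokens = per_token_timesteps(mask_tokens, 0.8)
```

Keeping the first-frame timestep at zero tells the denoiser that those tokens are already clean, which is how the conditioning image stays fixed across steps in this scheme.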
# ---------------------------------------------------------------------------
# Dimension Alignment Tests
# ---------------------------------------------------------------------------

class TestDimensionAlignment:
    """Tests for automatic dimension alignment in generate_wan."""

    def test_already_aligned(self):
        """Dimensions already divisible by alignment factor should be unchanged."""
        # patch_size=(1,2,2), vae_stride=(4,16,16) → align = 32
        align_h = 2 * 16  # 32
        align_w = 2 * 16  # 32
        h, w = 704, 1280
        assert h % align_h == 0
        assert w % align_w == 0
        h_aligned = (h // align_h) * align_h
        w_aligned = (w // align_w) * align_w
        assert h_aligned == h
        assert w_aligned == w

    def test_720p_rounds_down(self):
        """720p (1280x720) should round height to 704."""
        align_h = 32
        align_w = 32
        h, w = 720, 1280
        assert h % align_h != 0  # 720 not divisible by 32
        h_aligned = (h // align_h) * align_h
        w_aligned = (w // align_w) * align_w
        assert h_aligned == 704
        assert w_aligned == 1280

    def test_1080p_rounds_down(self):
        """1080p (1920x1080) should round height to 1056."""
        align = 32
        h, w = 1080, 1920
        assert h % align != 0
        assert (h // align) * align == 1056
        assert (w // align) * align == 1920

    def test_odd_sizes(self):
        """Odd sizes should be safely rounded down."""
        align = 32
        for size in [100, 255, 513, 1023]:
            aligned = (size // align) * align
            assert aligned % align == 0
            assert aligned <= size
            assert aligned + align > size  # closest lower multiple

    def test_patchify_valid_after_alignment(self):
        """After alignment, patchify should succeed without reshape errors."""
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)

        # Simulate 720p-like scenario with tiny config
        vae_stride = config.vae_stride  # (4, 8, 8)
        patch_size = config.patch_size  # (1, 2, 2)
        align_h = patch_size[1] * vae_stride[1]
        align_w = patch_size[2] * vae_stride[2]

        # Pick a height not divisible by alignment
        raw_h = align_h * 3 + 5  # e.g. 53 for align=16
        raw_w = align_w * 4
        h = (raw_h // align_h) * align_h  # rounds down
        w = (raw_w // align_w) * align_w

        C = config.in_dim
        t_latent = 1
        h_latent = h // vae_stride[1]
        w_latent = w // vae_stride[2]

        vid = mx.random.normal((C, t_latent, h_latent, w_latent))
        patches, grid_size = model._patchify(vid)
        mx.eval(patches)
        assert patches.ndim == 3  # [1, L, dim]
        assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])

    def test_alignment_with_ti2v_config(self):
        """TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
        from mlx_video.models.wan.config import WanModelConfig
        config = WanModelConfig.wan22_ti2v_5b()
        align_h = config.patch_size[1] * config.vae_stride[1]
        align_w = config.patch_size[2] * config.vae_stride[2]
        assert align_h == 32
        assert align_w == 32
        # 720 not divisible
        assert 720 % align_h != 0
        # 704 is
        assert 704 % align_h == 0
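The alignment tests above repeat the same arithmetic inline: the required multiple per spatial axis is `patch_size * vae_stride`, and a requested size is rounded down to the nearest such multiple. A small sketch of that rule as a reusable function follows. The names `align_down` and `aligned_resolution` are hypothetical and are not taken from `generate_wan.py`; the default arguments mirror the TI2V-5B values quoted in the tests.

```python
def align_down(size: int, multiple: int) -> int:
    """Largest multiple of `multiple` that does not exceed `size`."""
    return (size // multiple) * multiple


def aligned_resolution(height, width, patch_size=(1, 2, 2), vae_stride=(4, 16, 16)):
    """Round (height, width) down so the latent grid divides evenly into patches.

    Hypothetical helper: the per-axis multiple is patch size times VAE stride.
    """
    align_h = patch_size[1] * vae_stride[1]
    align_w = patch_size[2] * vae_stride[2]
    return align_down(height, align_h), align_down(width, align_w)


res_720p = aligned_resolution(720, 1280)    # TI2V-5B style, multiple of 32
res_1080p = aligned_resolution(1080, 1920)
res_stride8 = aligned_resolution(720, 1280, vae_stride=(4, 8, 8))  # multiple of 16
```

Rounding down rather than up keeps the output inside the requested frame, at the cost of dropping a few rows for sizes such as 720 that are not multiples of 32.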
332
tests/test_wan_model.py
Normal file
@@ -0,0 +1,332 @@
"""Tests for Wan model components."""

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest

from wan_test_helpers import _make_tiny_config


# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------

class TestSinusoidalEmbedding:
    def test_output_shape(self):
        from mlx_video.models.wan.model import sinusoidal_embedding_1d
        pos = mx.arange(10).astype(mx.float32)
        emb = sinusoidal_embedding_1d(256, pos)
        mx.eval(emb)
        assert emb.shape == (10, 256)

    def test_position_zero(self):
        """Position 0 should have cos=1 for all dims and sin=0."""
        from mlx_video.models.wan.model import sinusoidal_embedding_1d
        pos = mx.array([0.0])
        emb = sinusoidal_embedding_1d(64, pos)
        mx.eval(emb)
        emb_np = np.array(emb[0])
        # First half is cos, should be 1 at position 0
        np.testing.assert_allclose(emb_np[:32], 1.0, atol=1e-5)
        # Second half is sin, should be 0 at position 0
        np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)

    def test_different_positions_differ(self):
        from mlx_video.models.wan.model import sinusoidal_embedding_1d
        pos = mx.array([0.0, 100.0, 999.0])
        emb = sinusoidal_embedding_1d(128, pos)
        mx.eval(emb)
        emb_np = np.array(emb)
        assert not np.allclose(emb_np[0], emb_np[1])
        assert not np.allclose(emb_np[1], emb_np[2])


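The embedding tests above fix the layout (cosine terms in the first half, sine terms in the second half) but not the frequency schedule. The pure-Python sketch below uses the common `base ** (-i / half)` schedule with `base = 10000`, which is an assumption and has not been checked against `sinusoidal_embedding_1d` in `model.py`. The function name `sinusoidal_embedding` is likewise hypothetical.

```python
import math


def sinusoidal_embedding(dim, positions, base=10000.0):
    """Return one [cos..., sin...] row of length `dim` per position.

    Sketch only: the cos-first layout follows the tests above, while the
    frequency schedule is an assumed, conventional choice.
    """
    half = dim // 2
    freqs = [base ** (-i / half) for i in range(half)]
    rows = []
    for pos in positions:
        angles = [pos * f for f in freqs]
        rows.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return rows


emb = sinusoidal_embedding(64, [0.0, 100.0, 999.0])
```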
# ---------------------------------------------------------------------------
# Head Tests
# ---------------------------------------------------------------------------

class TestHead:
    def test_output_shape(self):
        from mlx_video.models.wan.model import Head
        head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        B, L = 1, 24
        x = mx.random.normal((B, L, 64))
        e = mx.random.normal((B, 64))  # time embedding: [B, dim]
        out = head(x, e)
        mx.eval(out)
        expected_proj_dim = 16 * 1 * 2 * 2  # 64
        assert out.shape == (B, L, expected_proj_dim)

    def test_modulation_shape(self):
        from mlx_video.models.wan.model import Head
        head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
        assert head.modulation.shape == (1, 2, 64)


# ---------------------------------------------------------------------------
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------

class TestWanModel:
    def setup_method(self):
        mx.random.seed(42)

    def test_instantiation(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
        assert num_params > 0

    def test_patchify_shape(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        # Input: [C=4, F=1, H=4, W=4]
        x = mx.random.normal((4, 1, 4, 4))
        patches, grid_size = model._patchify(x)
        mx.eval(patches)
        # Patch size (1,2,2): F'=1, H'=2, W'=2
        assert grid_size == (1, 2, 2)
        assert patches.shape == (1, 1 * 2 * 2, config.dim)

    def test_patchify_various_sizes(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
            x = mx.random.normal((config.in_dim, f, h, w))
            patches, (gf, gh, gw) = model._patchify(x)
            mx.eval(patches)
            pt, ph, pw = config.patch_size
            assert gf == f // pt
            assert gh == h // ph
            assert gw == w // pw
            assert patches.shape[1] == gf * gh * gw

    def test_unpatchify_inverse(self):
        """Patchify then unpatchify should reconstruct original spatial dims."""
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 2, 4, 6
        pt, ph, pw = config.patch_size
        F_out, H_out, W_out = F // pt, H // ph, W // pw
        L = F_out * H_out * W_out
        proj_dim = config.out_dim * pt * ph * pw
        # Simulated head output
        x = mx.random.normal((1, L, proj_dim))
        out = model.unpatchify(x, [(F_out, H_out, W_out)])
        mx.eval(out[0])
        assert out[0].shape == (config.out_dim, F, H, W)

    def test_forward_pass(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x_list = [mx.random.normal((C, F, H, W))]
        t = mx.array([500.0])
        context = [mx.random.normal((6, config.text_dim))]

        out = model(x_list, t, context, seq_len)
        mx.eval(out[0])
        assert len(out) == 1
        assert out[0].shape == (C, F, H, W)

    def test_forward_batch(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        pt, ph, pw = config.patch_size
        seq_len = (F // pt) * (H // ph) * (W // pw)

        x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
        t = mx.array([500.0, 200.0])
        context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]

        out = model(x_list, t, context, seq_len)
        mx.eval(out[0], out[1])
        assert len(out) == 2
        for o in out:
            assert o.shape == (C, F, H, W)

    def test_output_is_float32(self):
        from mlx_video.models.wan.model import WanModel
        config = _make_tiny_config()
        model = WanModel(config)
        C, F, H, W = config.in_dim, 1, 4, 4
        seq_len = (F // 1) * (H // 2) * (W // 2)
        out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
                    [mx.random.normal((4, config.text_dim))], seq_len)
        mx.eval(out[0])
        assert out[0].dtype == mx.float32


# ---------------------------------------------------------------------------
|
||||||
|
# Wan2.1 Model Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWan21Model:
|
||||||
|
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
|
||||||
|
def _make_tiny_wan21_config(self):
|
||||||
|
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_14b()
|
||||||
|
# Override to tiny values
|
||||||
|
config.dim = 64
|
||||||
|
config.ffn_dim = 128
|
||||||
|
config.num_heads = 4
|
||||||
|
config.num_layers = 2
|
||||||
|
config.in_dim = 4
|
||||||
|
config.out_dim = 4
|
||||||
|
config.freq_dim = 32
|
||||||
|
config.text_dim = 32
|
||||||
|
config.text_len = 8
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _make_tiny_wan21_1_3b_config(self):
|
||||||
|
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig.wan21_t2v_1_3b()
|
||||||
|
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||||
|
config.dim = 48
|
||||||
|
config.ffn_dim = 96
|
||||||
|
config.num_heads = 4
|
||||||
|
config.num_layers = 2
|
||||||
|
config.in_dim = 4
|
||||||
|
config.out_dim = 4
|
||||||
|
config.freq_dim = 24
|
||||||
|
config.text_dim = 24
|
||||||
|
config.text_len = 8
|
||||||
|
return config
|
||||||
|
|
||||||
|
def test_wan21_tiny_model_forward(self):
|
||||||
|
"""Forward pass with Wan2.1 tiny config."""
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
|
config = self._make_tiny_wan21_config()
|
||||||
|
model = WanModel(config)
|
||||||
|
|
||||||
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
|
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||||
|
|
||||||
|
latents = mx.random.normal((C, F, H, W))
|
||||||
|
context = mx.random.normal((4, config.text_dim))
|
||||||
|
t = mx.array([500.0])
|
||||||
|
|
||||||
|
out = model([latents], t, [context], seq_len)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out[0].shape == (C, F, H, W)
|
||||||
|
|
||||||
|
def test_wan21_1_3b_tiny_model_forward(self):
|
||||||
|
"""Forward pass with Wan2.1 1.3B tiny config."""
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
|
||||||
|
config = self._make_tiny_wan21_1_3b_config()
|
||||||
|
model = WanModel(config)
|
||||||
|
|
||||||
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
|
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||||
|
|
||||||
|
latents = mx.random.normal((C, F, H, W))
|
||||||
|
context = mx.random.normal((4, config.text_dim))
|
||||||
|
t = mx.array([500.0])
|
||||||
|
|
||||||
|
out = model([latents], t, [context], seq_len)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out[0].shape == (C, F, H, W)
|
||||||
|
|
||||||
|
def test_wan21_single_model_loop(self):
|
||||||
|
"""Full diffusion loop with single model (Wan2.1 style)."""
|
||||||
|
from mlx_video.models.wan.model import WanModel
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
|
config = self._make_tiny_wan21_config()
|
||||||
|
model = WanModel(config)
|
||||||
|
|
||||||
|
C, F, H, W = config.in_dim, 1, 4, 4
|
||||||
|
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||||
|
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(config.sample_steps, shift=config.sample_shift)
|
||||||
|
|
||||||
|
# Inputs for the loop below, which runs only 3 steps for speed
|
||||||
|
latents = mx.random.normal((C, F, H, W))
|
||||||
|
context = mx.random.normal((4, config.text_dim))
|
||||||
|
context_null = mx.zeros((4, config.text_dim))
|
||||||
|
gs = config.sample_guide_scale # Should be float for Wan2.1
|
||||||
|
|
||||||
|
assert isinstance(gs, float), "Wan2.1 guide_scale should be float"
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
t = sched.timesteps[i]
|
||||||
|
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||||
|
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
|
||||||
|
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||||
|
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
assert latents.shape == (C, F, H, W)
|
||||||
|
assert not mx.any(mx.isnan(latents)).item()
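The guidance combination used in the loop above, `pred_uncond + gs * (pred_cond - pred_uncond)`, is the usual classifier-free guidance formula. A minimal NumPy sketch follows; the helper name `apply_cfg` and the small arrays standing in for model outputs are assumptions made for illustration, not part of this commit.

```python
import numpy as np

def apply_cfg(pred_cond, pred_uncond, guide_scale):
    """Extrapolate from the unconditional toward the conditional prediction."""
    return pred_uncond + guide_scale * (pred_cond - pred_uncond)

# Made-up stand-ins for the two model outputs (hypothetical values).
cond = np.array([1.0, 2.0])
uncond = np.array([0.0, 1.0])

same_as_uncond = apply_cfg(cond, uncond, 0.0)  # scale 0 ignores the prompt
same_as_cond = apply_cfg(cond, uncond, 1.0)    # scale 1 is the plain conditional output
amplified = apply_cfg(cond, uncond, 5.0)       # scale > 1 overshoots past cond
```

A scale above 1 pushes the prediction further along the direction from the unconditional to the conditional output, which is why the test treats `gs` as a plain float multiplier.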
|
||||||
|
|
||||||
|
def test_wan21_vs_wan22_config_differences(self):
|
||||||
|
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
|
||||||
|
c21 = WanModelConfig.wan21_t2v_14b()
|
||||||
|
c22 = WanModelConfig.wan22_t2v_14b()
|
||||||
|
|
||||||
|
# Same architecture
|
||||||
|
assert c21.dim == c22.dim
|
||||||
|
assert c21.num_heads == c22.num_heads
|
||||||
|
assert c21.num_layers == c22.num_layers
|
||||||
|
|
||||||
|
# Different pipeline settings
|
||||||
|
assert c21.dual_model is False
|
||||||
|
assert c22.dual_model is True
|
||||||
|
assert isinstance(c21.sample_guide_scale, float)
|
||||||
|
assert isinstance(c22.sample_guide_scale, tuple)
|
||||||
|
assert c21.sample_shift != c22.sample_shift
|
||||||
|
assert c21.sample_steps != c22.sample_steps
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-Token Timestep Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPerTokenTimestep:
|
||||||
|
"""Tests for per-token sinusoidal embedding."""
|
||||||
|
|
||||||
|
def test_1d_unchanged(self):
|
||||||
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
pos = mx.array([0.0, 100.0, 500.0])
|
||||||
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
|
assert emb.shape == (3, 256)
|
||||||
|
|
||||||
|
def test_2d_per_token(self):
|
||||||
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
||||||
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
|
assert emb.shape == (2, 3, 256)
|
||||||
|
|
||||||
|
def test_consistency(self):
|
||||||
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
pos_1d = mx.array([0.0, 100.0])
|
||||||
|
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
||||||
|
pos_2d = mx.array([[0.0, 100.0]])
|
||||||
|
emb_2d = sinusoidal_embedding_1d(256, pos_2d)
|
||||||
|
assert mx.array_equal(emb_1d[0], emb_2d[0, 0])
|
||||||
|
assert mx.array_equal(emb_1d[1], emb_2d[0, 1])
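The shape contract checked by `TestPerTokenTimestep` (a `(B,)` input gives `(B, dim)`, a `(B, L)` input gives `(B, L, dim)`, and both agree element-wise) falls out of broadcasting over a trailing frequency axis. The sketch below is a standalone NumPy illustration, not the `sinusoidal_embedding_1d` from `model.py`; the function name, the base of 10000, and the cos-then-sin layout are assumptions.

```python
import numpy as np

def sinusoidal_embedding(dim, pos):
    """Embed positions of any shape (...,) into (..., dim) sin/cos features."""
    half = dim // 2
    # Assumed frequency ladder: geometric spacing with base 10000.
    freqs = 10000.0 ** (-np.arange(half) / half)
    # A new trailing axis lets the same code serve per-sample and per-token input.
    angles = np.asarray(pos, dtype=np.float64)[..., None] * freqs
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)

emb_1d = sinusoidal_embedding(256, [0.0, 100.0])    # one timestep per sample
emb_2d = sinusoidal_embedding(256, [[0.0, 100.0]])  # one timestep per token
```

Because only the leading shape changes, each token's embedding depends solely on its own timestep value, which is the property `test_consistency` relies on.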
|
||||||
903
tests/test_wan_scheduler.py
Normal file
@@ -0,0 +1,903 @@
|
|||||||
|
"""Tests for Wan scheduler components."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Euler Scheduler Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFlowMatchEulerScheduler:
|
||||||
|
def test_initialization(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
assert sched.num_train_timesteps == 1000
|
||||||
|
assert sched.timesteps is None
|
||||||
|
assert sched.sigmas is None
|
||||||
|
|
||||||
|
def test_set_timesteps(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(40, shift=12.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (40,)
|
||||||
|
assert sched.sigmas.shape == (41,) # 40 steps + terminal
|
||||||
|
|
||||||
|
def test_timesteps_decreasing(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(40, shift=12.0)
|
||||||
|
mx.eval(sched.timesteps)
|
||||||
|
ts = np.array(sched.timesteps)
|
||||||
|
# Timesteps should be monotonically decreasing
|
||||||
|
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
|
||||||
|
|
||||||
|
def test_sigmas_decreasing(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(20, shift=1.0)
|
||||||
|
mx.eval(sched.sigmas)
|
||||||
|
sigmas = np.array(sched.sigmas)
|
||||||
|
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
|
||||||
|
|
||||||
|
def test_terminal_sigma_is_zero(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(20, shift=5.0)
|
||||||
|
mx.eval(sched.sigmas)
|
||||||
|
np.testing.assert_allclose(np.array(sched.sigmas[-1]), 0.0, atol=1e-6)
|
||||||
|
|
||||||
|
def test_shift_effect(self):
|
||||||
|
"""Larger shift should push sigmas toward higher values."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched1 = FlowMatchEulerScheduler()
|
||||||
|
sched2 = FlowMatchEulerScheduler()
|
||||||
|
sched1.set_timesteps(20, shift=1.0)
|
||||||
|
sched2.set_timesteps(20, shift=12.0)
|
||||||
|
mx.eval(sched1.sigmas, sched2.sigmas)
|
||||||
|
mean1 = np.mean(np.array(sched1.sigmas[:-1]))
|
||||||
|
mean2 = np.mean(np.array(sched2.sigmas[:-1]))
|
||||||
|
assert mean2 > mean1, "Higher shift should push sigmas higher"
|
||||||
|
|
||||||
|
def test_step_euler(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(10, shift=1.0)
|
||||||
|
mx.eval(sched.sigmas)
|
||||||
|
|
||||||
|
sample = mx.ones((1, 4, 2, 2, 2))
|
||||||
|
velocity = mx.ones((1, 4, 2, 2, 2)) * 0.5
|
||||||
|
timestep = sched.timesteps[0]
|
||||||
|
|
||||||
|
sigma = float(np.array(sched.sigmas[0]))
|
||||||
|
sigma_next = float(np.array(sched.sigmas[1]))
|
||||||
|
|
||||||
|
result = sched.step(velocity, timestep, sample)
|
||||||
|
mx.eval(result)
|
||||||
|
|
||||||
|
# Euler: x_next = x + (sigma_next - sigma) * v
|
||||||
|
expected = 1.0 + (sigma_next - sigma) * 0.5
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(result).flatten()[0], expected, rtol=1e-4,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_step_index_increments(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
assert sched._step_index == 0
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
|
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||||
|
sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
assert sched._step_index == 1
|
||||||
|
sched.step(vel, sched.timesteps[1], sample)
|
||||||
|
assert sched._step_index == 2
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
|
vel = mx.zeros((1, 1, 1, 1, 1))
|
||||||
|
sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
assert sched._step_index == 1
|
||||||
|
sched.reset()
|
||||||
|
assert sched._step_index == 0
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
|
||||||
|
def test_various_step_counts(self, steps):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(steps, shift=12.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (steps,)
|
||||||
|
assert sched.sigmas.shape == (steps + 1,)
|
||||||
|
|
||||||
|
def test_full_denoise_loop(self):
|
||||||
|
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||||
|
sched = FlowMatchEulerScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
|
for i in range(5):
|
||||||
|
vel = mx.zeros_like(sample)
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
# With zero velocity, sample should remain unchanged
|
||||||
|
np.testing.assert_allclose(np.array(sample), 1.0, atol=1e-5)
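The Euler update that these tests exercise, `x_next = x + (sigma_next - sigma) * v`, can be sketched in plain NumPy. This is a standalone illustration rather than the `FlowMatchEulerScheduler` code; the helper name `euler_step` and the linear sigma schedule are assumptions made for the example.

```python
import numpy as np

def euler_step(x, v, sigma, sigma_next):
    """One flow-matching Euler step: move x along velocity v by the sigma delta."""
    return x + (sigma_next - sigma) * v

# Hypothetical linear schedule from 1 to 0 (five steps plus the terminal sigma).
sigmas = np.linspace(1.0, 0.0, 6)

# Zero velocity leaves the sample untouched, as in test_full_denoise_loop.
x = np.ones((2, 2))
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    x = euler_step(x, np.zeros_like(x), s, s_next)
unchanged = x

# Oracle velocity v = x / sigma (target x0 = 0): each step rescales x by
# sigma_next / sigma, so the product telescopes to sigma_last / sigma_first = 0.
y = np.full((2, 2), 3.0)
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    y = euler_step(y, y / s, s, s_next)
denoised = y
```

The second loop shows why an oracle of the form `v = x / sigma` drives the sample to zero regardless of step count: the terminal sigma of 0 zeroes the telescoped product.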
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared Sigma Schedule Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeSigmas:
|
||||||
|
"""Tests for the shared _compute_sigmas helper."""
|
||||||
|
|
||||||
|
def test_length(self):
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
|
assert len(sigmas) == 21 # num_steps + terminal
|
||||||
|
|
||||||
|
def test_terminal_zero(self):
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
sigmas = _compute_sigmas(10, shift=1.0)
|
||||||
|
assert sigmas[-1] == 0.0
|
||||||
|
|
||||||
|
def test_starts_at_one(self):
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
|
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-6)
|
||||||
|
|
||||||
|
def test_decreasing(self):
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
sigmas = _compute_sigmas(20, shift=5.0)
|
||||||
|
assert np.all(np.diff(sigmas) <= 0)
|
||||||
|
|
||||||
|
def test_matches_official_wan22(self):
|
||||||
|
"""Sigma schedule should match the official Wan2.2 get_sampling_sigmas."""
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
steps, shift = 50, 5.0
|
||||||
|
sigmas = _compute_sigmas(steps, shift)
|
||||||
|
# Official: sigma = linspace(1, 0, steps+1)[:steps]; sigma = shift*sigma/(1+(shift-1)*sigma)
|
||||||
|
official = np.linspace(1, 0, steps + 1)[:steps]
|
||||||
|
official = shift * official / (1 + (shift - 1) * official)
|
||||||
|
official = np.append(official, 0.0).astype(np.float32)
|
||||||
|
np.testing.assert_allclose(sigmas, official, atol=1e-6)
|
||||||
|
|
||||||
|
def test_shift_one_is_linear(self):
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
sigmas = _compute_sigmas(10, shift=1.0)
|
||||||
|
# With shift=1, f(sigma)=sigma, so schedule is linear from 1 to 0
|
||||||
|
expected = np.linspace(1, 0, 11).astype(np.float32)
|
||||||
|
np.testing.assert_allclose(sigmas, expected, atol=1e-6)
|
||||||
|
|
||||||
|
def test_all_schedulers_same_sigmas(self):
|
||||||
|
"""All three schedulers should produce identical sigma schedules."""
|
||||||
|
from mlx_video.models.wan.scheduler import (
|
||||||
|
FlowDPMPP2MScheduler,
|
||||||
|
FlowMatchEulerScheduler,
|
||||||
|
FlowUniPCScheduler,
|
||||||
|
)
|
||||||
|
scheds = [
|
||||||
|
FlowMatchEulerScheduler(1000),
|
||||||
|
FlowDPMPP2MScheduler(1000),
|
||||||
|
FlowUniPCScheduler(1000),
|
||||||
|
]
|
||||||
|
for s in scheds:
|
||||||
|
s.set_timesteps(20, shift=5.0)
|
||||||
|
mx.eval(*[s.sigmas for s in scheds])
|
||||||
|
ref = np.array(scheds[0].sigmas)
|
||||||
|
for s in scheds[1:]:
|
||||||
|
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
|
||||||
|
|
||||||
|
def test_all_schedulers_same_timesteps(self):
|
||||||
|
from mlx_video.models.wan.scheduler import (
|
||||||
|
FlowDPMPP2MScheduler,
|
||||||
|
FlowMatchEulerScheduler,
|
||||||
|
FlowUniPCScheduler,
|
||||||
|
)
|
||||||
|
scheds = [
|
||||||
|
FlowMatchEulerScheduler(1000),
|
||||||
|
FlowDPMPP2MScheduler(1000),
|
||||||
|
FlowUniPCScheduler(1000),
|
||||||
|
]
|
||||||
|
for s in scheds:
|
||||||
|
s.set_timesteps(30, shift=12.0)
|
||||||
|
mx.eval(*[s.timesteps for s in scheds])
|
||||||
|
ref = np.array(scheds[0].timesteps)
|
||||||
|
for s in scheds[1:]:
|
||||||
|
np.testing.assert_allclose(np.array(s.timesteps), ref, atol=1e-3)
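The shifted schedule that `test_matches_official_wan22` compares against can be written as a few lines of NumPy, following the formula quoted in that test's comment. The function name `shifted_sigmas` is hypothetical; the actual `_compute_sigmas` helper may be implemented differently.

```python
import numpy as np

def shifted_sigmas(num_steps, shift):
    """Sigma schedule per the formula quoted in test_matches_official_wan22."""
    # Uniform grid on [1, 0), dropping the endpoint so the terminal 0 is appended once.
    sigma = np.linspace(1.0, 0.0, num_steps + 1)[:num_steps]
    # Shift warp: identity for shift=1, pushes sigmas toward 1 for shift>1.
    sigma = shift * sigma / (1.0 + (shift - 1.0) * sigma)
    return np.append(sigma, 0.0).astype(np.float32)

linear = shifted_sigmas(10, shift=1.0)
warped = shifted_sigmas(10, shift=5.0)
```

For any sigma in [0, 1] and shift >= 1, the warped value is at least sigma, which is why a larger shift spends more of the schedule at high noise levels.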
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DPM++ 2M Scheduler Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlowDPMPP2MScheduler:
|
||||||
|
def test_initialization(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
assert sched.num_train_timesteps == 1000
|
||||||
|
assert sched.lower_order_final is True
|
||||||
|
|
||||||
|
def test_set_timesteps(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(20, shift=5.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (20,)
|
||||||
|
assert sched.sigmas.shape == (21,)
|
||||||
|
|
||||||
|
def test_step_index_increments(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 4, 1, 2, 2))
|
||||||
|
vel = mx.zeros_like(sample)
|
||||||
|
assert sched._step_index == 0
|
||||||
|
sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
assert sched._step_index == 1
|
||||||
|
sched.step(vel, sched.timesteps[1], sample)
|
||||||
|
assert sched._step_index == 2
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
|
sched.step(mx.zeros_like(sample), sched.timesteps[0], sample)
|
||||||
|
sched.reset()
|
||||||
|
assert sched._step_index == 0
|
||||||
|
assert sched._prev_x0 is None
|
||||||
|
|
||||||
|
def test_full_loop_finite(self):
|
||||||
|
"""Full loop with constant velocity should produce finite output."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(10, shift=1.0)
|
||||||
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
|
for i in range(10):
|
||||||
|
vel = mx.ones_like(sample) * 0.1
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
assert np.isfinite(np.array(sample)).all()
|
||||||
|
|
||||||
|
def test_first_step_is_first_order(self):
|
||||||
|
"""First step should use 1st-order (no prev_x0 available)."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(10, shift=5.0)
|
||||||
|
sample = mx.random.normal((1, 4, 2, 4, 4))
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
# Before first step, no prev_x0
|
||||||
|
assert sched._prev_x0 is None
|
||||||
|
result = sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
mx.eval(result)
|
||||||
|
# After first step, prev_x0 should be set
|
||||||
|
assert sched._prev_x0 is not None
|
||||||
|
|
||||||
|
def test_second_step_uses_correction(self):
|
||||||
|
"""After first step, DPM++ should have stored prev_x0 for correction."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(10, shift=5.0)
|
||||||
|
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
# Step 1
|
||||||
|
sample = sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
x0_after_first = sched._prev_x0
|
||||||
|
# Step 2
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
sample = sched.step(vel, sched.timesteps[1], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
# prev_x0 should have been updated
|
||||||
|
x0_after_second = sched._prev_x0
|
||||||
|
assert x0_after_second is not None
|
||||||
|
# The stored x0 should differ from the first step's
|
||||||
|
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
|
||||||
|
|
||||||
|
def test_denoise_to_target(self):
|
||||||
|
"""Perfect oracle should denoise to target with any solver."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(20, shift=5.0)
|
||||||
|
target = mx.zeros((1, 2, 1, 4, 4))
|
||||||
|
latents = mx.random.normal(target.shape)
|
||||||
|
for i in range(20):
|
||||||
|
sigma = float(sched.sigmas[i].item())
|
||||||
|
v = latents / max(sigma, 1e-6) # perfect velocity for target=0
|
||||||
|
latents = sched.step(v, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||||
|
def test_various_step_counts(self, steps):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(steps, shift=5.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (steps,)
|
||||||
|
assert sched.sigmas.shape == (steps + 1,)
|
||||||
|
|
||||||
|
def test_terminal_sigma_produces_x0(self):
|
||||||
|
"""When sigma_next=0 the scheduler should return x0 directly."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
sched = FlowDPMPP2MScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
|
||||||
|
vel = mx.ones_like(sample) * 2.0
|
||||||
|
# Run through all steps; the last step has sigma_next=0
|
||||||
|
for i in range(5):
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
# Only finiteness is asserted; the exact x0 value is not checked
|
||||||
|
assert np.isfinite(np.array(sample)).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# UniPC Scheduler Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlowUniPCScheduler:
|
||||||
|
def test_initialization(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
assert sched.num_train_timesteps == 1000
|
||||||
|
assert sched.solver_order == 2
|
||||||
|
assert sched.lower_order_final is True
|
||||||
|
|
||||||
|
def test_set_timesteps(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(30, shift=12.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (30,)
|
||||||
|
assert sched.sigmas.shape == (31,)
|
||||||
|
|
||||||
|
def test_step_index_increments(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
|
vel = mx.zeros_like(sample)
|
||||||
|
assert sched._step_index == 0
|
||||||
|
sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
assert sched._step_index == 1
|
||||||
|
|
||||||
|
def test_reset(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 1, 1))
|
||||||
|
sched.step(mx.zeros_like(sample), sched.timesteps[0], sample)
|
||||||
|
sched.reset()
|
||||||
|
assert sched._step_index == 0
|
||||||
|
assert sched._lower_order_nums == 0
|
||||||
|
assert sched._last_sample is None
|
||||||
|
assert all(m is None for m in sched._model_outputs)
|
||||||
|
|
||||||
|
def test_full_loop_finite(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(10, shift=1.0)
|
||||||
|
sample = mx.ones((1, 2, 1, 2, 2))
|
||||||
|
for i in range(10):
|
||||||
|
vel = mx.ones_like(sample) * 0.1
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
assert np.isfinite(np.array(sample)).all()
|
||||||
|
|
||||||
|
def test_corrector_not_applied_first_step(self):
|
||||||
|
"""First step should skip the corrector (no history)."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler(use_corrector=True)
|
||||||
|
sched.set_timesteps(10, shift=5.0)
|
||||||
|
sample = mx.random.normal((1, 4, 1, 2, 2))
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
# Before step 0: no last_sample
|
||||||
|
assert sched._last_sample is None
|
||||||
|
sched.step(vel, sched.timesteps[0], sample)
|
||||||
|
# After step 0: last_sample should be set for corrector on step 1
|
||||||
|
assert sched._last_sample is not None
|
||||||
|
|
||||||
|
def test_corrector_applied_after_first_step(self):
|
||||||
|
"""Steps after the first should use the corrector when enabled."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler(use_corrector=True)
|
||||||
|
sched.set_timesteps(10, shift=5.0)
|
||||||
|
sample = mx.random.normal((1, 2, 1, 4, 4))
|
||||||
|
for i in range(3):
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
# lower_order_nums should have reached at least 2 after three steps
|
||||||
|
assert sched._lower_order_nums >= 2
|
||||||
|
|
||||||
|
def test_denoise_to_target(self):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(20, shift=5.0)
|
||||||
|
target = mx.zeros((1, 2, 1, 4, 4))
|
||||||
|
latents = mx.random.normal(target.shape)
|
||||||
|
for i in range(20):
|
||||||
|
sigma = float(sched.sigmas[i].item())
|
||||||
|
v = latents / max(sigma, 1e-6)
|
||||||
|
latents = sched.step(v, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
np.testing.assert_allclose(np.array(latents), 0.0, atol=1e-3)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||||
|
def test_various_step_counts(self, steps):
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
sched.set_timesteps(steps, shift=5.0)
|
||||||
|
mx.eval(sched.timesteps, sched.sigmas)
|
||||||
|
assert sched.timesteps.shape == (steps,)
|
||||||
|
assert sched.sigmas.shape == (steps + 1,)
|
||||||
|
|
||||||
|
def test_disable_corrector(self):
|
||||||
|
"""Disabling corrector on step 0 should still work without error."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
|
||||||
|
sched.set_timesteps(5, shift=1.0)
|
||||||
|
sample = mx.ones((1, 1, 1, 2, 2))
|
||||||
|
for i in range(5):
|
||||||
|
vel = mx.ones_like(sample) * 0.1
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
assert np.isfinite(np.array(sample)).all()
|
||||||
|
|
||||||
|
def test_solver_order_3(self):
|
||||||
|
"""Order 3 should work without error."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
|
||||||
|
sched.set_timesteps(10, shift=5.0)
|
||||||
|
sample = mx.random.normal((1, 2, 1, 2, 2))
|
||||||
|
for i in range(10):
|
||||||
|
vel = mx.random.normal(sample.shape)
|
||||||
|
sample = sched.step(vel, sched.timesteps[i], sample)
|
||||||
|
mx.eval(sample)
|
||||||
|
assert np.isfinite(np.array(sample)).all()
|
||||||
|
|
||||||
|
def test_corrector_rhos_c_not_hardcoded(self):
|
||||||
|
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
|
||||||
|
import math
|
||||||
|
# For a 50-step schedule with shift=5.0, order-2 corrector at step 5:
|
||||||
|
# rhos_c[0] (history) should be ~0.07, NOT 0.5
|
||||||
|
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
|
||||||
|
from mlx_video.models.wan.scheduler import _compute_sigmas
|
||||||
|
|
||||||
|
sigmas = _compute_sigmas(50, shift=5.0)
|
||||||
|
|
||||||
|
def _lambda(sigma):
|
||||||
|
if sigma >= 1.0:
|
||||||
|
return -math.inf
|
||||||
|
if sigma <= 0.0:
|
||||||
|
return math.inf
|
||||||
|
return math.log(1 - sigma) - math.log(sigma)
|
||||||
|
|
||||||
|
for step_idx in [5, 10, 25, 45]:
|
||||||
|
sigma_s0 = sigmas[step_idx - 1]
|
||||||
|
sigma_t = sigmas[step_idx]
|
||||||
|
lambda_s0 = _lambda(sigma_s0)
|
||||||
|
lambda_t = _lambda(sigma_t)
|
||||||
|
h = lambda_t - lambda_s0
|
||||||
|
hh = -h
|
||||||
|
|
||||||
|
sigma_sk = sigmas[step_idx - 2]
|
||||||
|
lambda_sk = _lambda(sigma_sk)
|
||||||
|
rk = (lambda_sk - lambda_s0) / h
|
||||||
|
rks = np.array([rk, 1.0])
|
||||||
|
|
||||||
|
h_phi_1 = math.expm1(hh)
|
||||||
|
B_h = h_phi_1
|
||||||
|
h_phi_k = h_phi_1 / hh - 1.0
|
||||||
|
factorial_i = 1
|
||||||
|
R_rows, b_vals = [], []
|
||||||
|
for j in range(1, 3):
|
||||||
|
R_rows.append(rks ** (j - 1))
|
||||||
|
b_vals.append(h_phi_k * factorial_i / B_h)
|
||||||
|
factorial_i *= j + 1
|
||||||
|
h_phi_k = h_phi_k / hh - 1.0 / factorial_i
|
||||||
|
R = np.stack(R_rows)
|
||||||
|
b = np.array(b_vals)
|
||||||
|
rhos_c = np.linalg.solve(R, b)
|
||||||
|
|
||||||
|
# History weight should be small (~0.07-0.09), not 0.5
|
||||||
|
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
|
||||||
|
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
|
||||||
|
# D1_t weight should be ~0.42-0.45, not 0.5
|
||||||
|
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scheduler Coherence Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSchedulerCoherence:
|
||||||
|
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
|
||||||
|
|
||||||
|
All three schedulers should agree on shared structure (sigma schedules,
|
||||||
|
first-step behavior) and converge to the same result given perfect
|
||||||
|
velocity oracles, even though they use different update rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_schedulers(steps=10, shift=5.0):
|
||||||
|
from mlx_video.models.wan.scheduler import (
|
||||||
|
FlowDPMPP2MScheduler,
|
||||||
|
FlowMatchEulerScheduler,
|
||||||
|
FlowUniPCScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
scheds = {
|
||||||
|
"euler": FlowMatchEulerScheduler(),
|
||||||
|
"dpm++": FlowDPMPP2MScheduler(),
|
||||||
|
"unipc": FlowUniPCScheduler(),
|
||||||
|
}
|
||||||
|
for s in scheds.values():
|
||||||
|
s.set_timesteps(steps, shift=shift)
|
||||||
|
return scheds
|
||||||
|
|
||||||
|
def test_identical_sigma_schedules(self):
|
||||||
|
"""All schedulers must use the same sigma schedule."""
|
||||||
|
scheds = self._make_schedulers(20, shift=5.0)
|
||||||
|
ref = np.array(scheds["euler"].sigmas)
|
||||||
|
for name in ("dpm++", "unipc"):
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(scheds[name].sigmas),
|
||||||
|
ref,
|
||||||
|
atol=1e-6,
|
||||||
|
err_msg=f"{name} sigma schedule differs from Euler",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_identical_timesteps(self):
|
||||||
|
"""All schedulers must produce the same timestep sequence."""
|
||||||
|
scheds = self._make_schedulers(20, shift=5.0)
|
||||||
|
ref = np.array(scheds["euler"].timesteps)
|
||||||
|
for name in ("dpm++", "unipc"):
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(scheds[name].timesteps),
|
||||||
|
ref,
|
||||||
|
atol=1e-6,
|
||||||
|
err_msg=f"{name} timesteps differ from Euler",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_first_step_matches_euler(self):
|
||||||
|
"""Step 0 (1st-order for all solvers) should match Euler exactly."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 4, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
|
||||||
|
scheds = self._make_schedulers(10, shift=5.0)
|
||||||
|
results = {}
|
||||||
|
for name, sched in scheds.items():
|
||||||
|
r = sched.step(vel, sched.timesteps[0], noise)
|
||||||
|
mx.eval(r)
|
||||||
|
results[name] = np.array(r)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
results["dpm++"], results["euler"], atol=1e-5,
|
||||||
|
err_msg="DPM++ step 0 should match Euler",
|
||||||
|
)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
results["unipc"], results["euler"], atol=1e-5,
|
||||||
|
err_msg="UniPC step 0 should match Euler",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_first_step_matches_across_shifts(self):
|
||||||
|
"""Step 0 should match Euler for different shift values."""
|
||||||
|
mx.random.seed(99)
|
||||||
|
shape = (1, 2, 1, 2, 2)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
|
||||||
|
for shift in (1.0, 5.0, 12.0):
|
||||||
|
scheds = self._make_schedulers(10, shift=shift)
|
||||||
|
euler_r = scheds["euler"].step(vel, scheds["euler"].timesteps[0], noise)
|
||||||
|
dpm_r = scheds["dpm++"].step(vel, scheds["dpm++"].timesteps[0], noise)
|
||||||
|
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
|
||||||
|
mx.eval(euler_r, dpm_r, unipc_r)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(dpm_r), np.array(euler_r), atol=1e-5,
|
||||||
|
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
|
||||||
|
)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(unipc_r), np.array(euler_r), atol=1e-5,
|
||||||
|
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_oracle_all_converge_to_target(self):
|
||||||
|
"""Given a perfect velocity oracle v=x/sigma, all solvers should
|
||||||
|
denoise to approximately zero (the target)."""
|
||||||
|
mx.random.seed(7)
|
||||||
|
shape = (1, 2, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
for name, sched in self._make_schedulers(20, shift=5.0).items():
|
||||||
|
latents = noise
|
||||||
|
for i in range(20):
|
||||||
|
sigma = float(sched.sigmas[i].item())
|
||||||
|
v = latents / max(sigma, 1e-8)
|
||||||
|
latents = sched.step(v, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(latents), 0.0, atol=1e-3,
|
||||||
|
err_msg=f"{name} did not converge to target with oracle",
|
||||||
|
)
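Why the oracle drives the solvers to zero can be checked by hand for the first-order case. Assuming the Euler update is `x_next = x + (sigma_next - sigma) * v` (the usual flow-matching form; the scheduler source is not shown in this diff), substituting `v = x / sigma` gives `x_next = x * sigma_next / sigma`. The product telescopes to `x_0 * sigma_N / sigma_0`, which is zero when the final sigma is zero. A NumPy sketch follows; `shifted_sigmas` is a hypothetical helper using the common shift formula, not the scheduler's actual code:

```python
import numpy as np


def shifted_sigmas(steps: int, shift: float) -> np.ndarray:
    """Hypothetical schedule: linear sigmas from 1 to 0, warped by `shift`."""
    s = np.linspace(1.0, 0.0, steps + 1)
    return shift * s / (1.0 + (shift - 1.0) * s)


def euler_with_oracle(x: np.ndarray, sigmas: np.ndarray) -> np.ndarray:
    """Run first-order Euler steps with the oracle velocity v = x / sigma."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = x / max(sigma, 1e-8)          # same guard as the test above
        x = x + (sigma_next - sigma) * v  # assumed Euler update
    return x


noise = np.random.default_rng(7).standard_normal((1, 2, 1, 4, 4))
final = euler_with_oracle(noise, shifted_sigmas(20, shift=5.0))
```

The higher-order solvers are not modelled here; the sketch only illustrates why zero is the expected target.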
|
||||||
|
|
||||||
|
def test_oracle_higher_order_closer_to_target(self):
|
||||||
|
"""With few steps and a perfect oracle, higher-order solvers should
|
||||||
|
have a mean absolute error no more than 1.5x that of Euler."""
|
||||||
|
mx.random.seed(12)
|
||||||
|
shape = (1, 2, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
steps = 5
|
||||||
|
|
||||||
|
errors = {}
|
||||||
|
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||||
|
latents = noise
|
||||||
|
for i in range(steps):
|
||||||
|
sigma = float(sched.sigmas[i].item())
|
||||||
|
v = latents / max(sigma, 1e-8)
|
||||||
|
latents = sched.step(v, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
errors[name] = float(mx.mean(mx.abs(latents)).item())
|
||||||
|
|
||||||
|
# Higher-order solvers should not be significantly worse than Euler
|
||||||
|
assert errors["dpm++"] <= errors["euler"] * 1.5, (
|
||||||
|
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||||
|
)
|
||||||
|
assert errors["unipc"] <= errors["euler"] * 1.5, (
|
||||||
|
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multistep_trajectory_similar_magnitude(self):
|
||||||
|
"""Over a full denoising loop with velocity 0.1 * latents, all solvers
|
||||||
|
should produce outputs of similar magnitude (not diverging)."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 4, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
steps = 20
|
||||||
|
|
||||||
|
final_means = {}
|
||||||
|
for name, sched in self._make_schedulers(steps, shift=5.0).items():
|
||||||
|
latents = noise
|
||||||
|
for i in range(steps):
|
||||||
|
vel = latents * 0.1
|
||||||
|
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
final_means[name] = float(mx.mean(mx.abs(latents)).item())
|
||||||
|
|
||||||
|
# All solvers should produce results within the same order of magnitude
|
||||||
|
vals = list(final_means.values())
|
||||||
|
ratio = max(vals) / max(min(vals), 1e-10)
|
||||||
|
assert ratio < 10.0, (
|
||||||
|
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_intermediate_values_finite(self):
|
||||||
|
"""Every intermediate latent value must be finite for all solvers."""
|
||||||
|
mx.random.seed(0)
|
||||||
|
shape = (1, 2, 1, 2, 2)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
for name, sched in self._make_schedulers(15, shift=5.0).items():
|
||||||
|
latents = noise
|
||||||
|
for i in range(15):
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
assert np.isfinite(np.array(latents)).all(), (
|
||||||
|
f"{name} produced non-finite values at step {i}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lambda_boundary_values(self):
|
||||||
|
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
|
||||||
|
from mlx_video.models.wan.scheduler import (
|
||||||
|
FlowDPMPP2MScheduler,
|
||||||
|
FlowUniPCScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||||
|
assert cls._lambda(1.0) == -math.inf, (
|
||||||
|
f"{cls.__name__}._lambda(1.0) should be -inf"
|
||||||
|
)
|
||||||
|
assert cls._lambda(0.0) == math.inf, (
|
||||||
|
f"{cls.__name__}._lambda(0.0) should be +inf"
|
||||||
|
)
|
||||||
|
# Interior values should be finite
|
||||||
|
lam = cls._lambda(0.5)
|
||||||
|
assert math.isfinite(lam) and lam == 0.0, (
|
||||||
|
f"{cls.__name__}._lambda(0.5) should be 0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lambda_monotonically_decreasing(self):
|
||||||
|
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
|
||||||
|
|
||||||
|
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
|
||||||
|
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
|
||||||
|
for i in range(len(lambdas) - 1):
|
||||||
|
assert lambdas[i] > lambdas[i + 1], (
|
||||||
|
f"_lambda not decreasing: _lambda({sigmas[i]})={lambdas[i]} "
|
||||||
|
f"vs _lambda({sigmas[i+1]})={lambdas[i+1]}"
|
||||||
|
)
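The two `_lambda` tests above pin down its behaviour at the boundaries, at `sigma = 0.5`, and its monotonicity. One function consistent with all of these assertions is the half log-SNR `log((1 - sigma) / sigma)`. The sketch below is an assumption about what `_lambda` computes, not a copy of the scheduler code, and `half_log_snr` is a hypothetical name:

```python
import math


def half_log_snr(sigma: float) -> float:
    """Hypothetical stand-in for `_lambda`: log(alpha / sigma), alpha = 1 - sigma."""
    if sigma >= 1.0:
        return -math.inf  # pure noise: zero signal
    if sigma <= 0.0:
        return math.inf   # clean sample: zero noise
    return math.log((1.0 - sigma) / sigma)
```

Handling the two endpoints explicitly avoids `log(0)` and division by zero at the first and last steps of the schedule.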
|
||||||
|
|
||||||
|
def test_step0_is_ddim_formula(self):
|
||||||
|
"""At sigma=1.0, the DPM++/UniPC first step should reduce to the
|
||||||
|
DDIM formula: x_next = sigma_next * x + (1 - sigma_next) * x0."""
|
||||||
|
mx.random.seed(55)
|
||||||
|
shape = (1, 2, 1, 2, 2)
|
||||||
|
sample = mx.random.normal(shape)
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
|
||||||
|
for steps, shift in [(10, 5.0), (20, 12.0)]:
|
||||||
|
scheds = self._make_schedulers(steps, shift=shift)
|
||||||
|
sigma_next = float(scheds["euler"].sigmas[1].item())
|
||||||
|
sigma_cur = float(scheds["euler"].sigmas[0].item())
|
||||||
|
assert abs(sigma_cur - 1.0) < 1e-6, "First sigma should be ~1.0"
|
||||||
|
|
||||||
|
x0 = sample - sigma_cur * vel
|
||||||
|
expected = sigma_next * sample + (1.0 - sigma_next) * x0
|
||||||
|
mx.eval(expected)
|
||||||
|
|
||||||
|
for name in ("dpm++", "unipc"):
|
||||||
|
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
|
||||||
|
mx.eval(result)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(result), np.array(expected), atol=1e-5,
|
||||||
|
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
|
||||||
|
)
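The DDIM form used as the expected value above is algebraically the same as a plain Euler step when `sigma_cur = 1`: expanding `sigma_next * x + (1 - sigma_next) * (x - v)` gives `x + (sigma_next - 1) * v`. The NumPy check below assumes the Euler update `x + (sigma_next - sigma_cur) * v`; both function names are hypothetical:

```python
import numpy as np


def euler_step(x, v, sigma_cur, sigma_next):
    """Assumed first-order flow-matching update."""
    return x + (sigma_next - sigma_cur) * v


def ddim_first_step(x, v, sigma_next, sigma_cur=1.0):
    """DDIM-style blend of the sample and the predicted clean sample x0."""
    x0 = x - sigma_cur * v
    return sigma_next * x + (1.0 - sigma_next) * x0


rng = np.random.default_rng(55)
x = rng.standard_normal((1, 2, 1, 2, 2))
v = rng.standard_normal((1, 2, 1, 2, 2))
euler_out = euler_step(x, v, 1.0, 0.9)
ddim_out = ddim_first_step(x, v, 0.9)
```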
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
|
||||||
|
def test_coherent_across_step_counts(self, steps):
|
||||||
|
"""All solvers should agree on step 0 regardless of total step count."""
|
||||||
|
mx.random.seed(77)
|
||||||
|
shape = (1, 2, 1, 2, 2)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
|
||||||
|
scheds = self._make_schedulers(steps, shift=5.0)
|
||||||
|
results = {}
|
||||||
|
for name, sched in scheds.items():
|
||||||
|
r = sched.step(vel, sched.timesteps[0], noise)
|
||||||
|
mx.eval(r)
|
||||||
|
results[name] = np.array(r)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
results["dpm++"], results["euler"], atol=1e-5,
|
||||||
|
)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
results["unipc"], results["euler"], atol=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dpmpp_unipc_agree_on_step1(self):
|
||||||
|
"""After warmup, DPM++ and UniPC step 1 should be similar
|
||||||
|
(both use 2nd-order corrections based on the same model outputs)."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 4, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
scheds = self._make_schedulers(10, shift=5.0)
|
||||||
|
# Run step 0 with same velocity
|
||||||
|
vel0 = mx.random.normal(shape)
|
||||||
|
for sched in scheds.values():
|
||||||
|
sched.step(vel0, sched.timesteps[0], noise)
|
||||||
|
|
||||||
|
# Run step 1 from same sample with same velocity
|
||||||
|
sample1 = scheds["euler"].step(vel0, scheds["euler"].timesteps[0], noise)
|
||||||
|
mx.eval(sample1)
|
||||||
|
vel1 = mx.random.normal(shape)
|
||||||
|
|
||||||
|
r_dpm = scheds["dpm++"].step(vel1, scheds["dpm++"].timesteps[1], sample1)
|
||||||
|
r_unipc = scheds["unipc"].step(vel1, scheds["unipc"].timesteps[1], sample1)
|
||||||
|
mx.eval(r_dpm, r_unipc)
|
||||||
|
|
||||||
|
# They won't be identical (different correction formulas) but should
|
||||||
|
# be in the same ballpark (mean magnitudes within a factor of 2)
|
||||||
|
mean_dpm = float(mx.mean(mx.abs(r_dpm)).item())
|
||||||
|
mean_unipc = float(mx.mean(mx.abs(r_unipc)).item())
|
||||||
|
ratio = max(mean_dpm, mean_unipc) / max(min(mean_dpm, mean_unipc), 1e-10)
|
||||||
|
assert ratio < 2.0, (
|
||||||
|
f"DPM++ and UniPC step 1 differ too much: "
|
||||||
|
f"DPM++={mean_dpm:.4f}, UniPC={mean_unipc:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_reset_makes_solvers_reproducible(self):
|
||||||
|
"""After reset(), running the same loop should produce identical output."""
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 2, 1, 2, 2)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
|
||||||
|
|
||||||
|
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
|
||||||
|
sched = cls()
|
||||||
|
sched.set_timesteps(5, shift=5.0)
|
||||||
|
|
||||||
|
# First run
|
||||||
|
latents = noise
|
||||||
|
for i in range(5):
|
||||||
|
vel = latents * 0.1
|
||||||
|
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
result1 = np.array(latents)
|
||||||
|
|
||||||
|
# Reset and run again
|
||||||
|
sched.reset()
|
||||||
|
latents = noise
|
||||||
|
for i in range(5):
|
||||||
|
vel = latents * 0.1
|
||||||
|
latents = sched.step(vel, sched.timesteps[i], latents)
|
||||||
|
mx.eval(latents)
|
||||||
|
result2 = np.array(latents)
|
||||||
|
|
||||||
|
np.testing.assert_allclose(result1, result2, atol=1e-5,
|
||||||
|
err_msg=f"{cls.__name__} not reproducible after reset()")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# UniPC Corrector Default Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUniPCCorrectorDefault:
|
||||||
|
"""Tests that the UniPC corrector is enabled by default,
|
||||||
|
matching official FlowUniPCMultistepScheduler behavior."""
|
||||||
|
|
||||||
|
def test_corrector_enabled_by_default(self):
|
||||||
|
"""Default construction should have corrector enabled."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
sched = FlowUniPCScheduler()
|
||||||
|
assert sched._use_corrector is True
|
||||||
|
|
||||||
|
def test_corrector_affects_output(self):
|
||||||
|
"""Corrector should produce different results than no corrector after step 1."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 4, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||||
|
sched_corr.set_timesteps(10, shift=5.0)
|
||||||
|
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||||
|
sched_no.set_timesteps(10, shift=5.0)
|
||||||
|
|
||||||
|
latent_corr = noise
|
||||||
|
latent_no = noise
|
||||||
|
for i in range(3):
|
||||||
|
vel = mx.random.normal(shape) * 0.1
|
||||||
|
latent_corr = sched_corr.step(vel, sched_corr.timesteps[i], latent_corr)
|
||||||
|
latent_no = sched_no.step(vel, sched_no.timesteps[i], latent_no)
|
||||||
|
mx.eval(latent_corr, latent_no)
|
||||||
|
|
||||||
|
diff = float(mx.abs(latent_corr - latent_no).max())
|
||||||
|
assert diff > 1e-6, f"Corrector had no effect (max diff={diff})"
|
||||||
|
|
||||||
|
def test_corrector_does_not_affect_first_step(self):
|
||||||
|
"""Step 0 should be identical regardless of corrector setting."""
|
||||||
|
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
|
||||||
|
mx.random.seed(42)
|
||||||
|
shape = (1, 4, 1, 4, 4)
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
vel = mx.random.normal(shape)
|
||||||
|
|
||||||
|
sched_corr = FlowUniPCScheduler(use_corrector=True)
|
||||||
|
sched_corr.set_timesteps(10, shift=5.0)
|
||||||
|
sched_no = FlowUniPCScheduler(use_corrector=False)
|
||||||
|
sched_no.set_timesteps(10, shift=5.0)
|
||||||
|
|
||||||
|
r1 = sched_corr.step(vel, sched_corr.timesteps[0], noise)
|
||||||
|
r2 = sched_no.step(vel, sched_no.timesteps[0], noise)
|
||||||
|
mx.eval(r1, r2)
|
||||||
|
np.testing.assert_allclose(np.array(r1), np.array(r2), atol=1e-6)
|
||||||
173
tests/test_wan_t5.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Tests for T5 encoder components."""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# T5 Encoder Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestT5LayerNorm:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||||
|
norm = T5LayerNorm(64)
|
||||||
|
x = mx.random.normal((2, 10, 64))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (2, 10, 64)
|
||||||
|
|
||||||
|
def test_rms_normalization(self):
|
||||||
|
"""After T5LayerNorm with weight=1, RMS should be ~1."""
|
||||||
|
from mlx_video.models.wan.text_encoder import T5LayerNorm
|
||||||
|
norm = T5LayerNorm(128)
|
||||||
|
x = mx.random.normal((1, 5, 128)) * 5.0
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
out_np = np.array(out[0])
|
||||||
|
for i in range(5):
|
||||||
|
rms = np.sqrt(np.mean(out_np[i] ** 2))
|
||||||
|
np.testing.assert_allclose(rms, 1.0, rtol=0.1)
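For reference, the property checked above follows from the standard T5 layer norm, which is an RMS norm with no mean subtraction and no bias. The NumPy sketch below assumes that definition; `rms_norm` and its `eps` default are illustrative and are not taken from `text_encoder.py`:

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Scale each vector along the last axis to unit RMS, then apply `weight`."""
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return weight * x / np.sqrt(variance + eps)


x = np.random.default_rng(0).standard_normal((5, 128)) * 5.0
out = rms_norm(x, np.ones(128))
```

With `weight` set to ones, each row of `out` has RMS close to 1 regardless of the input scale.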
|
||||||
|
|
||||||
|
|
||||||
|
class TestT5RelativeEmbedding:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||||
|
out = rel_emb(10, 10)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
|
||||||
|
|
||||||
|
def test_asymmetric_lengths(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
|
||||||
|
out = rel_emb(8, 12)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 4, 8, 12)
|
||||||
|
|
||||||
|
def test_symmetry(self):
|
||||||
|
"""Diagonal of the position bias (relative distance 0) should be constant per head."""
|
||||||
|
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
|
||||||
|
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
|
||||||
|
out = rel_emb(6, 6)
|
||||||
|
mx.eval(out)
|
||||||
|
out_np = np.array(out[0]) # [N, lq, lk]
|
||||||
|
# Diagonal elements (position i attending to position i) should be consistent
|
||||||
|
# (same relative distance = 0 for all diagonal elements)
|
||||||
|
for h in range(2):
|
||||||
|
diag = np.diag(out_np[h])
|
||||||
|
np.testing.assert_allclose(diag, diag[0], atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestT5Attention:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
|
x = mx.random.normal((1, 10, 64))
|
||||||
|
out = attn(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 10, 64)
|
||||||
|
|
||||||
|
def test_no_scaling(self):
|
||||||
|
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
|
# No scale attribute (unlike standard attention)
|
||||||
|
assert not hasattr(attn, "scale")
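The attribute check above is structural only. What "no scaling" means numerically can be sketched in NumPy: the scores are the raw `q @ k^T` plus the relative position bias, with no division by `sqrt(head_dim)`. This is a minimal illustration with hypothetical names, not the `T5Attention` implementation:

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def t5_style_attention(q, k, v, pos_bias=None):
    """q, k, v: [heads, length, head_dim]; pos_bias: [heads, lq, lk] or None."""
    scores = q @ np.swapaxes(k, -1, -2)  # note: no 1/sqrt(head_dim) factor
    if pos_bias is not None:
        scores = scores + pos_bias       # additive relative position bias
    return softmax(scores) @ v


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 10, 16))
k = rng.standard_normal((4, 10, 16))
attn_out = t5_style_attention(q, k, np.ones((4, 10, 16)), rng.standard_normal((4, 10, 10)))
```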
|
||||||
|
|
||||||
|
def test_with_position_bias(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
|
||||||
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
|
rel_emb = T5RelativeEmbedding(32, 4)
|
||||||
|
x = mx.random.normal((1, 10, 64))
|
||||||
|
pos_bias = rel_emb(10, 10)
|
||||||
|
out = attn(x, pos_bias=pos_bias)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 10, 64)
|
||||||
|
|
||||||
|
def test_with_mask(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Attention
|
||||||
|
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
|
||||||
|
x = mx.random.normal((1, 10, 64))
|
||||||
|
mask = mx.ones((1, 10))
|
||||||
|
mask = mx.concatenate([mask[:, :7], mx.zeros((1, 3))], axis=1)
|
||||||
|
out = attn(x, mask=mask)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 10, 64)
|
||||||
|
|
||||||
|
|
||||||
|
class TestT5FeedForward:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||||
|
ffn = T5FeedForward(64, 256)
|
||||||
|
x = mx.random.normal((1, 10, 64))
|
||||||
|
out = ffn(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 10, 64)
|
||||||
|
|
||||||
|
def test_gated_structure(self):
|
||||||
|
"""T5 FFN is gated: gate(x) * fc1(x)."""
|
||||||
|
from mlx_video.models.wan.text_encoder import T5FeedForward
|
||||||
|
ffn = T5FeedForward(32, 64)
|
||||||
|
assert hasattr(ffn, "gate_proj")
|
||||||
|
assert hasattr(ffn, "fc1")
|
||||||
|
assert hasattr(ffn, "fc2")
|
||||||
|
|
||||||
|
|
||||||
|
class TestT5Encoder:
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||||
|
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||||
|
)
|
||||||
|
ids = mx.array([[1, 5, 10, 0, 0]])
|
||||||
|
mask = mx.array([[1, 1, 1, 0, 0]])
|
||||||
|
out = encoder(ids, mask=mask)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 5, 64)
|
||||||
|
|
||||||
|
def test_shared_pos(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||||
|
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
|
||||||
|
)
|
||||||
|
assert encoder.pos_embedding is not None
|
||||||
|
for block in encoder.blocks:
|
||||||
|
assert block.pos_embedding is None
|
||||||
|
|
||||||
|
def test_per_layer_pos(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||||
|
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||||
|
)
|
||||||
|
assert encoder.pos_embedding is None
|
||||||
|
for block in encoder.blocks:
|
||||||
|
assert block.pos_embedding is not None
|
||||||
|
|
||||||
|
def test_param_count(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||||
|
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||||
|
)
|
||||||
|
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
|
||||||
|
assert num_params > 0
|
||||||
|
|
||||||
|
def test_without_mask(self):
|
||||||
|
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||||
|
encoder = T5Encoder(
|
||||||
|
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
|
||||||
|
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
|
||||||
|
)
|
||||||
|
ids = mx.array([[1, 5, 10]])
|
||||||
|
out = encoder(ids)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 3, 64)
|
||||||
160
tests/test_wan_transformer.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""Tests for Wan transformer block components."""
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Transformer Block Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestWanFFN:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
ffn = WanFFN(64, 256)
|
||||||
|
x = mx.random.normal((2, 10, 64))
|
||||||
|
out = ffn(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (2, 10, 64)
|
||||||
|
|
||||||
|
def test_gelu_activation(self):
|
||||||
|
"""FFN should use GELU activation (non-linearity)."""
|
||||||
|
from mlx_video.models.wan.transformer import WanFFN
|
||||||
|
ffn = WanFFN(32, 128)
|
||||||
|
x = mx.ones((1, 1, 32)) * 2.0
|
||||||
|
out1 = ffn(x)
|
||||||
|
x2 = mx.ones((1, 1, 32)) * 4.0
|
||||||
|
out2 = ffn(x2)
|
||||||
|
mx.eval(out1, out2)
|
||||||
|
# Non-linear: 2x input should not give 2x output
|
||||||
|
assert not np.allclose(np.array(out2), np.array(out1) * 2.0, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanAttentionBlock:
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
self.dim = 64
|
||||||
|
self.ffn_dim = 128
|
||||||
|
self.num_heads = 4
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
block = WanAttentionBlock(
|
||||||
|
self.dim, self.ffn_dim, self.num_heads,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
)
|
||||||
|
B, L = 1, 24
|
||||||
|
F, H, W = 2, 3, 4
|
||||||
|
x = mx.random.normal((B, L, self.dim))
|
||||||
|
e = mx.random.normal((B, L, 6, self.dim))
|
||||||
|
context = mx.random.normal((B, 16, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
|
||||||
|
out = block(
|
||||||
|
x, e, seq_lens=[L], grid_sizes=[(F, H, W)],
|
||||||
|
freqs=freqs, context=context,
|
||||||
|
)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (B, L, self.dim)
|
||||||
|
|
||||||
|
def test_modulation_shape(self):
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||||
|
assert block.modulation.shape == (1, 6, self.dim)
|
||||||
|
|
||||||
|
def test_with_cross_attn_norm(self):
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
block = WanAttentionBlock(
|
||||||
|
self.dim, self.ffn_dim, self.num_heads,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
)
|
||||||
|
assert block.norm3 is not None
|
||||||
|
|
||||||
|
def test_without_cross_attn_norm(self):
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
block = WanAttentionBlock(
|
||||||
|
self.dim, self.ffn_dim, self.num_heads,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
)
|
||||||
|
assert block.norm3 is None
|
||||||
|
|
||||||
|
def test_residual_connection(self):
|
||||||
|
"""Output should differ from zero even with small random init."""
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||||
|
B, L = 1, 8
|
||||||
|
F, H, W = 2, 2, 2
|
||||||
|
x = mx.ones((B, L, self.dim))
|
||||||
|
e = mx.zeros((B, L, 6, self.dim))
|
||||||
|
context = mx.random.normal((B, 4, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // self.num_heads)
|
||||||
|
|
||||||
|
out = block(x, e, [L], [(F, H, W)], freqs, context)
|
||||||
|
mx.eval(out)
|
||||||
|
# With residual connections, output should be close to input + corrections
|
||||||
|
assert not np.allclose(np.array(out), 0.0, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Float32 Modulation Precision Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFloat32Modulation:
|
||||||
|
"""Tests that modulation/gate operations are computed in float32,
|
||||||
|
matching official torch.amp.autocast('cuda', dtype=torch.float32)."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
mx.random.seed(42)
|
||||||
|
self.dim = 64
|
||||||
|
|
||||||
|
def test_block_modulation_in_float32(self):
|
||||||
|
"""The block's modulation parameter should be stored as float32."""
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||||
|
assert block.modulation.dtype == mx.float32
|
||||||
|
|
||||||
|
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||||
|
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||||
|
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||||
|
from mlx_video.models.wan.rope import rope_params
|
||||||
|
block = WanAttentionBlock(self.dim, 128, 4)
|
||||||
|
B, L = 1, 8
|
||||||
|
x = mx.random.normal((B, L, self.dim))
|
||||||
|
e = mx.random.normal((B, L, 6, self.dim)).astype(mx.bfloat16)
|
||||||
|
ctx = mx.random.normal((B, 4, self.dim))
|
||||||
|
freqs = rope_params(1024, self.dim // 4)
|
||||||
|
|
||||||
|
out = block(x, e, [L], [(2, 2, 2)], freqs, ctx)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.dtype == mx.float32
|
||||||
|
assert np.isfinite(np.array(out)).all()
|
||||||
|
|
||||||
|
def test_head_modulation_float32(self):
|
||||||
|
"""Head modulation should be float32 even with bf16 e input."""
|
||||||
|
from mlx_video.models.wan.model import Head
|
||||||
|
head = Head(self.dim, 4, (1, 2, 2))
|
||||||
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
|
e = mx.random.normal((1, 8, self.dim)).astype(mx.bfloat16)
|
||||||
|
out = head(x, e)
|
||||||
|
mx.eval(out)
|
||||||
|
assert np.isfinite(np.array(out.astype(mx.float32))).all()
|
||||||
|
|
||||||
|
def test_model_time_embedding_float32(self):
|
||||||
|
"""sinusoidal_embedding_1d output must be float32."""
|
||||||
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
t = mx.array([500.0])
|
||||||
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
|
mx.eval(emb)
|
||||||
|
assert emb.dtype == mx.float32
|
||||||
|
|
||||||
|
def test_model_per_token_time_embedding_float32(self):
|
||||||
|
"""Per-token time embeddings (I2V) should also be float32."""
|
||||||
|
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||||
|
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||||
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
|
mx.eval(emb)
|
||||||
|
assert emb.dtype == mx.float32
|
||||||
|
assert emb.shape == (1, 4, 256)
|
||||||
871
tests/test_wan_vae.py
Normal file
@@ -0,0 +1,871 @@
|
|||||||
|
"""Tests for Wan VAE 2.1 and 2.2 components."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VAE 2.1 Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCausalConv3d:
|
||||||
|
def test_output_shape_stride1(self):
|
||||||
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||||
|
# Initialize weights
|
||||||
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||||
|
x = mx.random.normal((1, 4, 3, 8, 8)) # [B, C, T, H, W]
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# With causal padding and padding=1 on spatial, dims should be preserved
|
||||||
|
assert out.shape[0] == 1
|
||||||
|
assert out.shape[1] == 8 # out_channels
|
||||||
|
assert out.shape[2] == 3 # T preserved
|
||||||
|
assert out.shape[3] == 8 # H preserved
|
||||||
|
assert out.shape[4] == 8 # W preserved
|
||||||
|
|
||||||
|
def test_output_shape_kernel1(self):
|
||||||
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||||
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||||
|
x = mx.random.normal((1, 4, 2, 4, 4))
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
def test_causal_padding(self):
|
||||||
|
"""Causal conv should only use past/current frames, not future."""
|
||||||
|
from mlx_video.models.wan.vae import CausalConv3d
|
||||||
|
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||||
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||||
|
conv.bias = mx.zeros((2,))
|
||||||
|
# Create input where only the first frame has signal
|
||||||
|
x = mx.zeros((1, 2, 4, 4, 4))
|
||||||
|
x_np = np.zeros((1, 2, 4, 4, 4), dtype=np.float32)
|
||||||
|
x_np[:, :, 0, :, :] = 1.0
|
||||||
|
x = mx.array(x_np)
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# Due to causal padding, the output at t=0 should only depend on t=0.
# TODO: assert this; as written the test only checks that the forward pass runs.
|
||||||
|
|
||||||
|
|
||||||
|
class TestResidualBlock:
|
||||||
|
def test_same_dim(self):
|
||||||
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 8)
|
||||||
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
def test_different_dim(self):
|
||||||
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 16)
|
||||||
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 16, 2, 4, 4)
|
||||||
|
|
||||||
|
def test_shortcut_exists_when_dims_differ(self):
|
||||||
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 16)
|
||||||
|
assert block.shortcut is not None
|
||||||
|
|
||||||
|
def test_no_shortcut_when_dims_same(self):
|
||||||
|
from mlx_video.models.wan.vae import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 8)
|
||||||
|
assert block.shortcut is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAttentionBlock:
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae import AttentionBlock
|
||||||
|
block = AttentionBlock(8)
|
||||||
|
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
def test_residual_connection(self):
|
||||||
|
from mlx_video.models.wan.vae import AttentionBlock
|
||||||
|
block = AttentionBlock(8)
|
||||||
|
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(x, out)
|
||||||
|
# Residual: output should not be zero even with random init
|
||||||
|
assert np.abs(np.array(out)).max() > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanVAE:
|
||||||
|
def test_instantiation(self):
|
||||||
|
from mlx_video.models.wan.vae import WanVAE
|
||||||
|
vae = WanVAE(z_dim=16)
|
||||||
|
assert vae.z_dim == 16
|
||||||
|
assert vae.mean.shape == (16,)
|
||||||
|
assert vae.std.shape == (16,)
|
||||||
|
|
||||||
|
def test_normalization_stats(self):
|
||||||
|
from mlx_video.models.wan.vae import WanVAE, VAE_MEAN, VAE_STD
|
||||||
|
assert len(VAE_MEAN) == 16
|
||||||
|
assert len(VAE_STD) == 16
|
||||||
|
assert all(s > 0 for s in VAE_STD)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Wan2.2 VAE Component Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22CausalConv3d:
|
||||||
|
"""Tests for vae22.CausalConv3d (channels-last)."""
|
||||||
|
|
||||||
|
def test_output_shape_k3(self):
|
||||||
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||||
|
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 4, 8, 8, 16)
|
||||||
|
|
||||||
|
def test_output_shape_k1(self):
|
||||||
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_temporal_causal(self):
|
||||||
|
"""Output at t=0 should not depend on t>0."""
|
||||||
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||||
|
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||||
|
conv.bias = mx.zeros(conv.bias.shape)
|
||||||
|
|
||||||
|
x = mx.zeros((1, 4, 4, 4, 2))
|
||||||
|
out_zero = conv(x)
|
||||||
|
mx.eval(out_zero)
|
||||||
|
t0_ref = np.array(out_zero[0, 0])
|
||||||
|
|
||||||
|
# Modify t=2..3; output at t=0 should be unchanged
|
||||||
|
x_mod = mx.concatenate([
|
||||||
|
x[:, :2],
|
||||||
|
mx.ones((1, 2, 4, 4, 2)),
|
||||||
|
], axis=1)
|
||||||
|
out_mod = conv(x_mod)
|
||||||
|
mx.eval(out_mod)
|
||||||
|
t0_mod = np.array(out_mod[0, 0])
|
||||||
|
np.testing.assert_allclose(t0_ref, t0_mod, atol=1e-5)
|
||||||
|
|
||||||
|
def test_channels_last_format(self):
|
||||||
|
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||||
|
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||||
|
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||||
|
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||||
|
out = conv(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape[-1] == 8 # last dim = out_channels
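# Illustrative sketch (not the vae22 implementation): the causality property that
# test_temporal_causal checks can be reproduced in one dimension with NumPy.
# The helper name causal_conv1d and the front-only zero padding of k-1 frames
# are assumptions made for this example.

```python
import numpy as np


def causal_conv1d(x, w):
    """1D causal convolution: output[t] only sees x[t-k+1 .. t]."""
    k = len(w)
    # Pad k-1 zeros at the front only, so no future frame leaks into output[t].
    xp = np.concatenate([np.zeros(k - 1), x])
    return np.array([np.dot(xp[t:t + k], w) for t in range(len(x))])


w = np.array([0.5, -1.0, 2.0])
x_ref = np.zeros(4)
x_mod = np.array([0.0, 0.0, 1.0, 1.0])  # change frames t=2..3 only

out_ref = causal_conv1d(x_ref, w)
out_mod = causal_conv1d(x_mod, w)
```

# Frames t=0 and t=1 of out_mod match out_ref because they never see the
# modified frames; only t>=2 changes.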
|
||||||
|
|
||||||
|
|
||||||
|
class TestRMSNorm:
|
||||||
|
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
norm = RMS_norm(16)
|
||||||
|
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_l2_normalization(self):
|
||||||
|
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||||
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
dim = 32
|
||||||
|
norm = RMS_norm(dim)
|
||||||
|
x = mx.random.normal((1, 1, 1, 1, dim)) * 5.0 # large values
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# After L2 norm * scale(=sqrt(dim)) * gamma(=1): ||out|| = sqrt(dim)
|
||||||
|
out_np = np.array(out).flatten()
|
||||||
|
l2 = np.linalg.norm(out_np)
|
||||||
|
np.testing.assert_allclose(l2, math.sqrt(dim), rtol=1e-3)
|
||||||
|
|
||||||
|
def test_scale_invariant(self):
|
||||||
|
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||||
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
norm = RMS_norm(8)
|
||||||
|
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||||
|
out1 = norm(x)
|
||||||
|
out2 = norm(x * 10.0)
|
||||||
|
mx.eval(out1, out2)
|
||||||
|
np.testing.assert_allclose(np.array(out1), np.array(out2), atol=1e-4)
|
||||||
|
|
||||||
|
def test_gamma_effect(self):
|
||||||
|
"""Non-unit gamma should scale output."""
|
||||||
|
from mlx_video.models.wan.vae22 import RMS_norm
|
||||||
|
norm = RMS_norm(4)
|
||||||
|
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||||
|
x = mx.ones((1, 1, 1, 1, 4))
|
||||||
|
out = norm(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# With gamma=2, each component is 2 * sqrt(4) * x/||x|| = 2 * 2 * 1/2 = 2
|
||||||
|
np.testing.assert_allclose(np.array(out).flatten(), 2.0, atol=1e-4)
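# Illustrative sketch, assuming the formula stated in the test comments above
# (x / ||x|| * sqrt(dim) * gamma). l2_rms_norm is a hypothetical NumPy reference,
# not the vae22.RMS_norm code. It also shows why the "L2 normalization" form
# equals a plain RMS normalization without epsilon.

```python
import numpy as np


def l2_rms_norm(x, gamma, eps=1e-12):
    """L2-normalize the last axis, then rescale by sqrt(dim) and gamma."""
    dim = x.shape[-1]
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps) * np.sqrt(dim) * gamma


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 1, 1, 1, 32)) * 5.0
gamma = np.ones(32)

out = l2_rms_norm(x, gamma)
# Same quantity written as an RMS norm: x / sqrt(mean(x^2)).
rms_out = x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True))
out_scaled = l2_rms_norm(x * 10.0, gamma)
```

# out has L2 norm sqrt(32), matches rms_out, and is unchanged when the input is
# scaled by a constant, which mirrors the three properties tested above.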
|
||||||
|
|
||||||
|
|
||||||
|
class TestDupUp3D:
|
||||||
|
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
|
||||||
|
|
||||||
|
def test_spatial_only(self):
|
||||||
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
|
out = up(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 3, 8, 8, 4)
|
||||||
|
|
||||||
|
def test_temporal_and_spatial(self):
|
||||||
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||||
|
out = up(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 6, 8, 8, 8)
|
||||||
|
|
||||||
|
def test_first_chunk_trims(self):
|
||||||
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
|
out_normal = up(x, first_chunk=False)
|
||||||
|
out_trimmed = up(x, first_chunk=True)
|
||||||
|
mx.eval(out_normal, out_trimmed)
|
||||||
|
# first_chunk removes factor_t-1=1 temporal frame
|
||||||
|
assert out_normal.shape[1] == 6
|
||||||
|
assert out_trimmed.shape[1] == 5
|
||||||
|
|
||||||
|
def test_no_temporal_first_chunk_noop(self):
|
||||||
|
from mlx_video.models.wan.vae22 import DupUp3D
|
||||||
|
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||||
|
out_normal = up(x, first_chunk=False)
|
||||||
|
out_trimmed = up(x, first_chunk=True)
|
||||||
|
mx.eval(out_normal, out_trimmed)
|
||||||
|
# factor_t=1, so first_chunk removes 0 frames
|
||||||
|
assert out_normal.shape == out_trimmed.shape
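# Illustrative sketch of the frame-count arithmetic asserted by the DupUp3D tests
# above. dupup_frames is a hypothetical helper derived only from those
# assertions (first_chunk drops factor_t - 1 frames); it does not call DupUp3D.

```python
def dupup_frames(t, factor_t, first_chunk=False):
    """Temporal length after duplicating each frame factor_t times."""
    frames = t * factor_t
    if first_chunk:
        # The first chunk trims the factor_t - 1 leading duplicate frames.
        frames -= factor_t - 1
    return frames


normal = dupup_frames(3, factor_t=2)
trimmed = dupup_frames(3, factor_t=2, first_chunk=True)
spatial_only = dupup_frames(3, factor_t=1, first_chunk=True)
```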
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22Resample:
|
||||||
|
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
|
||||||
|
|
||||||
|
def test_upsample2d_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
r = Resample(8, "upsample2d")
|
||||||
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = r(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
|
||||||
|
|
||||||
|
def test_upsample3d_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
r = Resample(8, "upsample3d")
|
||||||
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = r(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
|
||||||
|
|
||||||
|
def test_upsample3d_first_chunk(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
r = Resample(8, "upsample3d")
|
||||||
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = r(x, first_chunk=True)
|
||||||
|
mx.eval(out)
|
||||||
|
# first_chunk: 1 (bypass) + 2*(T-1) (interleaved) = 2T-1 = 3
|
||||||
|
assert out.shape == (1, 3, 8, 8, 8)
|
||||||
|
|
||||||
|
def test_upsample3d_first_chunk_single_frame(self):
|
||||||
|
"""Single-frame input with first_chunk: falls through to the regular path (T=1 → 2)."""
|
||||||
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
r = Resample(8, "upsample3d")
|
||||||
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||||
|
x = mx.random.normal((1, 1, 4, 4, 8))
|
||||||
|
out = r(x, first_chunk=True)
|
||||||
|
mx.eval(out)
|
||||||
|
# Single frame with first_chunk: falls through to non-first path
|
||||||
|
# time_conv on 1 frame → 2 interleaved
|
||||||
|
assert out.shape == (1, 2, 8, 8, 8)
|
||||||
|
|
||||||
|
def test_upsample3d_first_frame_bypasses_time_conv(self):
|
||||||
|
"""First frame of first_chunk should NOT go through time_conv.
|
||||||
|
|
||||||
|
Official Wan2.2 skips time_conv for the very first frame entirely.
|
||||||
|
We verify this by checking that the first output frame depends only on
|
||||||
|
the first input frame (not on time_conv parameters).
|
||||||
|
"""
|
||||||
|
from mlx_video.models.wan.vae22 import Resample
|
||||||
|
C = 8
|
||||||
|
r = Resample(C, "upsample3d")
|
||||||
|
# Set time_conv weights to large values so its effect is detectable
|
||||||
|
r.time_conv.weight = mx.ones(r.time_conv.weight.shape) * 10.0
|
||||||
|
r.time_conv.bias = mx.zeros(r.time_conv.bias.shape)
|
||||||
|
# Use small random spatial conv weights so the comparison below is not trivially 0 == 0
|
||||||
|
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.1
|
||||||
|
r.resample_bias = mx.zeros(r.resample_bias.shape)
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 3, 2, 2, C))
|
||||||
|
out = r(x, first_chunk=True)
|
||||||
|
mx.eval(out)
|
||||||
|
# Output: 5 frames (1 bypass + 4 interleaved from 2 remaining)
|
||||||
|
assert out.shape[1] == 5
|
||||||
|
|
||||||
|
# First frame should be spatial upsample of x[:, 0:1] only.
|
||||||
|
# Run just the first frame through spatial upsample for reference
|
||||||
|
first_only = x[:, 0:1]
|
||||||
|
ref = r._upsample2x(first_only.reshape(1, 2, 2, C))
|
||||||
|
ref = mx.pad(ref, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||||
|
ref = mx.conv_general(ref, r.resample_weight) + r.resample_bias
|
||||||
|
mx.eval(ref)
|
||||||
|
|
||||||
|
# Compare first output frame to reference
|
||||||
|
first_out = out[:, 0:1].reshape(1, out.shape[2], out.shape[3], C)
|
||||||
|
mx.eval(first_out)
|
||||||
|
assert mx.allclose(first_out, ref, atol=1e-5).item(), \
|
||||||
|
"First frame should bypass time_conv and match spatial-only upsample"
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22ResidualBlock:
|
||||||
|
"""Tests for vae22.ResidualBlock."""
|
||||||
|
|
||||||
|
def test_same_dim(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 8)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 8)
|
||||||
|
|
||||||
|
def test_different_dim(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 16)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_shortcut_when_dims_differ(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 16)
|
||||||
|
assert block.shortcut is not None
|
||||||
|
|
||||||
|
def test_no_shortcut_same_dim(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||||
|
block = ResidualBlock(8, 8)
|
||||||
|
assert block.shortcut is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestResidualBlockLayers:
|
||||||
|
"""Tests for vae22.ResidualBlockLayers naming convention."""
|
||||||
|
|
||||||
|
def test_layer_names_no_underscore_prefix(self):
|
||||||
|
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
block = ResidualBlockLayers(8, 8)
|
||||||
|
params = dict(block.parameters())
|
||||||
|
# All param keys should use layer_N, not _layer_N
|
||||||
|
for key in params:
|
||||||
|
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
|
||||||
|
|
||||||
|
def test_has_expected_layers(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
block = ResidualBlockLayers(8, 16)
|
||||||
|
assert hasattr(block, "layer_0") # first RMS_norm
|
||||||
|
assert hasattr(block, "layer_2") # first CausalConv3d
|
||||||
|
assert hasattr(block, "layer_3") # second RMS_norm
|
||||||
|
assert hasattr(block, "layer_6") # second CausalConv3d
|
||||||
|
|
||||||
|
def test_forward_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||||
|
block = ResidualBlockLayers(8, 16)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22AttentionBlock:
|
||||||
|
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||||
|
block = AttentionBlock(16)
|
||||||
|
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||||
|
block.proj_weight = mx.random.normal(block.proj_weight.shape) * 0.01
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_residual_connection(self):
|
||||||
|
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||||
|
block = AttentionBlock(8)
|
||||||
|
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||||
|
block.proj_weight = mx.zeros(block.proj_weight.shape)
|
||||||
|
x = mx.ones((1, 1, 2, 2, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# With zero weights, attention output is 0 → residual is identity
|
||||||
|
np.testing.assert_allclose(np.array(out), np.array(x), atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHead22:
|
||||||
|
"""Tests for vae22.Head22 output head."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Head22
|
||||||
|
head = Head22(16, out_channels=12)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||||
|
out = head(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 12)
|
||||||
|
|
||||||
|
def test_layer_names_no_underscore(self):
|
||||||
|
"""Head layers must not use underscore prefix."""
|
||||||
|
from mlx_video.models.wan.vae22 import Head22
|
||||||
|
head = Head22(8)
|
||||||
|
assert hasattr(head, "layer_0") # RMS_norm
|
||||||
|
assert hasattr(head, "layer_2") # CausalConv3d
|
||||||
|
params = dict(head.parameters())
|
||||||
|
for key in params:
|
||||||
|
assert not key.startswith("_"), f"Head param {key} starts with underscore"
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnpatchify:
|
||||||
|
"""Tests for vae22._unpatchify."""
|
||||||
|
|
||||||
|
def test_basic_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||||
|
out = _unpatchify(x, patch_size=2)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 8, 8, 3)
|
||||||
|
|
||||||
|
def test_patch_size_1_noop(self):
|
||||||
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||||
|
out = _unpatchify(x, patch_size=1)
|
||||||
|
mx.eval(out)
|
||||||
|
np.testing.assert_array_equal(np.array(out), np.array(x))
|
||||||
|
|
||||||
|
def test_preserves_content(self):
|
||||||
|
"""Unpatchify should be a lossless rearrangement."""
|
||||||
|
from mlx_video.models.wan.vae22 import _unpatchify
|
||||||
|
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||||
|
out = _unpatchify(x, patch_size=2)
|
||||||
|
mx.eval(out)
|
||||||
|
# All elements should be preserved
|
||||||
|
assert np.array(out).size == 48
|
||||||
|
assert set(np.array(out).flatten().tolist()) == set(range(48))
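# Illustrative sketch of a lossless channels-last patchify/unpatchify pair in
# NumPy. The function names and the (p, p, C) channel ordering are assumptions
# for this example; vae22._patchify / _unpatchify may order the channel axis
# differently.

```python
import numpy as np


def patchify(x, p):
    """[B, T, H, W, C] -> [B, T, H/p, W/p, p*p*C]."""
    b, t, h, w, c = x.shape
    x = x.reshape(b, t, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)  # [B, T, H/p, W/p, p, p, C]
    return x.reshape(b, t, h // p, w // p, p * p * c)


def unpatchify(x, p):
    """[B, T, H, W, p*p*C] -> [B, T, H*p, W*p, C] (inverse of patchify)."""
    b, t, h, w, cpp = x.shape
    c = cpp // (p * p)
    x = x.reshape(b, t, h, w, p, p, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)  # [B, T, H, p, W, p, C]
    return x.reshape(b, t, h * p, w * p, c)


x = np.arange(48, dtype=np.float32).reshape(1, 1, 2, 2, 12)
out = unpatchify(x, 2)
back = patchify(out, 2)
```

# Both functions only reshape and transpose, so every element survives and the
# round trip is exact.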
|
||||||
|
|
||||||
|
|
||||||
|
class TestDenormalizeLatents:
|
||||||
|
"""Tests for vae22.denormalize_latents."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||||
|
out = denormalize_latents(z)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 48)
|
||||||
|
|
||||||
|
def test_custom_mean_std(self):
|
||||||
|
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||||
|
z = mx.ones((1, 1, 1, 1, 4))
|
||||||
|
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||||
|
std = mx.array([0.5, 0.5, 0.5, 0.5])
|
||||||
|
out = denormalize_latents(z, mean=mean, std=std)
|
||||||
|
mx.eval(out)
|
||||||
|
# z * std + mean = 1*0.5 + [1,2,3,4] = [1.5, 2.5, 3.5, 4.5]
|
||||||
|
np.testing.assert_allclose(np.array(out).flatten(), [1.5, 2.5, 3.5, 4.5], atol=1e-5)
|
||||||
|
|
||||||
|
def test_uses_default_constants(self):
|
||||||
|
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD, denormalize_latents
|
||||||
|
# Should not raise with default constants
|
||||||
|
z = mx.zeros((1, 1, 1, 1, 48))
|
||||||
|
out = denormalize_latents(z)
|
||||||
|
mx.eval(out)
|
||||||
|
# z=0 → result = 0 * std + mean = mean
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(out).flatten(),
|
||||||
|
np.array(VAE22_MEAN).flatten(),
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAE22NormConstants:
|
||||||
|
"""Tests for VAE22_MEAN and VAE22_STD constants."""
|
||||||
|
|
||||||
|
def test_dimensions(self):
|
||||||
|
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||||
|
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||||
|
assert VAE22_MEAN.shape == (48,)
|
||||||
|
assert VAE22_STD.shape == (48,)
|
||||||
|
|
||||||
|
def test_std_positive(self):
|
||||||
|
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||||
|
mx.eval(VAE22_STD)
|
||||||
|
assert (np.array(VAE22_STD) > 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
class TestWan22VAEDecoder:
|
||||||
|
"""Tests for the full Wan22VAEDecoder (tiny configuration)."""
|
||||||
|
|
||||||
|
def test_output_shape_small(self):
|
||||||
|
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
# Use very small dims to keep test fast
|
||||||
|
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||||
|
# Latent: [B=1, T=3, H=2, W=2, C=4]
|
||||||
|
# Expected: temporal 3→5→9→9→9 (two temporal upsamples), spatial 2→4→8→16
|
||||||
|
z = mx.random.normal((1, 3, 2, 2, 4)) * 0.1
|
||||||
|
out = dec(z)
|
||||||
|
mx.eval(out)
|
||||||
|
# Output should have 3 RGB channels and be clipped to [-1, 1]
|
||||||
|
assert out.shape[-1] == 3
|
||||||
|
assert out.ndim == 5
|
||||||
|
assert np.array(out).min() >= -1.0
|
||||||
|
assert np.array(out).max() <= 1.0
|
||||||
|
|
||||||
|
def test_output_clipped(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||||
|
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||||
|
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||||
|
out = dec(z)
|
||||||
|
mx.eval(out)
|
||||||
|
assert np.array(out).min() >= -1.0 - 1e-6
|
||||||
|
assert np.array(out).max() <= 1.0 + 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeWan22VAEWeights:
|
||||||
|
"""Tests for vae22.sanitize_wan22_vae_weights."""
|
||||||
|
|
||||||
|
def test_skip_encoder(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
weights = {
|
||||||
|
"encoder.layer.weight": mx.zeros((4,)),
|
||||||
|
"conv1.weight": mx.zeros((4,)),
|
||||||
|
"decoder.conv1.bias": mx.zeros((4,)),
|
||||||
|
}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
assert "encoder.layer.weight" not in out
|
||||||
|
assert "conv1.weight" not in out
|
||||||
|
assert "decoder.conv1.bias" in out
|
||||||
|
|
||||||
|
def test_sequential_index_remapping(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
weights = {
|
||||||
|
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||||
|
"decoder.upsamples.0.upsamples.0.residual.6.bias": mx.zeros((8,)),
|
||||||
|
"decoder.head.0.gamma": mx.ones((4,)),
|
||||||
|
"decoder.head.2.bias": mx.zeros((12,)),
|
||||||
|
}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
assert "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma" in out
|
||||||
|
assert "decoder.upsamples.0.upsamples.0.residual.layer_6.bias" in out
|
||||||
|
assert "decoder.head.layer_0.gamma" in out
|
||||||
|
assert "decoder.head.layer_2.bias" in out
|
||||||
|
|
||||||
|
def test_resample_conv_remapping(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
weights = {
|
||||||
|
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||||
|
"decoder.upsamples.1.upsamples.3.resample.1.bias": mx.zeros((8,)),
|
||||||
|
}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
assert "decoder.upsamples.1.upsamples.3.resample_weight" in out
|
||||||
|
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
|
||||||
|
|
||||||
|
def test_attention_remapping(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
weights = {
|
||||||
|
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||||
|
"decoder.middle.1.to_qkv.bias": mx.zeros((24,)),
|
||||||
|
"decoder.middle.1.proj.weight": mx.zeros((8, 8, 1, 1)),
|
||||||
|
"decoder.middle.1.proj.bias": mx.zeros((8,)),
|
||||||
|
}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
assert "decoder.middle.1.to_qkv_weight" in out
|
||||||
|
assert "decoder.middle.1.to_qkv_bias" in out
|
||||||
|
assert "decoder.middle.1.proj_weight" in out
|
||||||
|
assert "decoder.middle.1.proj_bias" in out
|
||||||
|
|
||||||
|
def test_conv3d_transpose(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||||
|
w = mx.zeros((16, 8, 3, 3, 3))
|
||||||
|
weights = {"decoder.conv1.weight": w}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
|
||||||
|
|
||||||
|
def test_conv2d_transpose(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||||
|
w = mx.zeros((8, 8, 3, 3))
|
||||||
|
weights = {"decoder.upsamples.0.upsamples.2.resample.1.weight": w}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
key = "decoder.upsamples.0.upsamples.2.resample_weight"
|
||||||
|
assert out[key].shape == (8, 3, 3, 8)
|
||||||
|
|
||||||
|
def test_gamma_squeeze(self):
|
||||||
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||||
|
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||||
|
w = mx.ones((16, 1, 1, 1))
|
||||||
|
weights = {"decoder.upsamples.0.upsamples.0.residual.0.gamma": w}
|
||||||
|
out = sanitize_wan22_vae_weights(weights)
|
||||||
|
key = "decoder.upsamples.0.upsamples.0.residual.layer_0.gamma"
|
||||||
|
assert out[key].shape == (16,)
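# Illustrative sketch of the layout changes asserted above, written with NumPy
# only. It assumes PyTorch-style channels-first conv weights as input and MLX's
# channels-last weight layout as output; it does not call
# sanitize_wan22_vae_weights. The variable names are hypothetical.

```python
import numpy as np

# Conv3d weight: [O, I, D, H, W] -> [O, D, H, W, I]
w3d = np.zeros((16, 8, 3, 5, 7))
w3d_mlx = w3d.transpose(0, 2, 3, 4, 1)

# Conv2d weight: [O, I, H, W] -> [O, H, W, I]
w2d = np.zeros((8, 4, 3, 5))
w2d_mlx = w2d.transpose(0, 2, 3, 1)

# Norm gamma stored as (dim, 1, 1, 1) -> flat (dim,)
gamma = np.ones((16, 1, 1, 1))
gamma_flat = gamma.reshape(-1)
```

# Distinct kernel sizes (3, 5, 7) are used here so that a wrong axis order would
# show up in the resulting shape.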
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpResidualBlock:
|
||||||
|
"""Tests for vae22.Up_ResidualBlock."""
|
||||||
|
|
||||||
|
def test_no_upsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
|
block = Up_ResidualBlock(8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# No upsample: same shape
|
||||||
|
assert out.shape == (1, 2, 4, 4, 8)
|
||||||
|
|
||||||
|
def test_spatial_upsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
|
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# 2x spatial upsample, no temporal
|
||||||
|
assert out.shape == (1, 2, 8, 8, 4)
|
||||||
|
|
||||||
|
def test_spatial_temporal_upsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||||
|
block = Up_ResidualBlock(8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True)
|
||||||
|
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# 2x spatial + 2x temporal
|
||||||
|
assert out.shape == (1, 4, 8, 8, 4)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatchify:
|
||||||
|
"""Tests for _patchify and _unpatchify round-trip."""
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 1, 64, 64, 3))
|
||||||
|
p = _patchify(x, patch_size=2)
|
||||||
|
assert p.shape == (1, 1, 32, 32, 12)
|
||||||
|
back = _unpatchify(p, patch_size=2)
|
||||||
|
assert back.shape == x.shape
|
||||||
|
assert float(mx.abs(x - back).max()) == 0.0
|
||||||
|
|
||||||
|
def test_identity_patch_1(self):
|
||||||
|
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 2, 8, 8, 3))
|
||||||
|
assert _patchify(x, patch_size=1).shape == x.shape
|
||||||
|
assert _unpatchify(x, patch_size=1).shape == x.shape
|
||||||
|
|
||||||
|
|
||||||
|
class TestAvgDown3D:
|
||||||
|
"""Tests for AvgDown3D downsampling."""
|
||||||
|
|
||||||
|
def test_spatial_only(self):
|
||||||
|
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||||
|
|
||||||
|
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||||
|
out = down(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_temporal_and_spatial(self):
|
||||||
|
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||||
|
|
||||||
|
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||||
|
out = down(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_single_frame(self):
|
||||||
|
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||||
|
|
||||||
|
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
|
||||||
|
x = mx.random.normal((1, 1, 8, 8, 8))
|
||||||
|
out = down(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# T=1 with factor_t=2: pads to T=2 then averages → T=1
|
||||||
|
assert out.shape == (1, 1, 4, 4, 8)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownResidualBlock:
|
||||||
|
"""Tests for Down_ResidualBlock."""
|
||||||
|
|
||||||
|
def test_no_downsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
|
block = Down_ResidualBlock(8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False)
|
||||||
|
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 8, 8, 8)
|
||||||
|
|
||||||
|
def test_spatial_downsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
|
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True)
|
||||||
|
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
def test_spatial_temporal_downsample(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||||
|
|
||||||
|
block = Down_ResidualBlock(8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True)
|
||||||
|
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||||
|
out = block(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape == (1, 2, 4, 4, 16)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncoder3d:
|
||||||
|
"""Tests for Encoder3d."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
|
|
||||||
|
enc = Encoder3d(dim=16, z_dim=8)
|
||||||
|
x = mx.random.normal((1, 1, 16, 16, 12))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
out = enc(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# 3 spatial downsamples ÷8: 16→2
|
||||||
|
assert out.shape == (1, 1, 2, 2, 8)
|
||||||
|
|
||||||
|
def test_multi_frame(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
|
|
||||||
|
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||||
|
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
out = enc(x)
|
||||||
|
mx.eval(out)
|
||||||
|
# T: 5→3 (1st t_down) →2 (2nd t_down), spatial ÷8
|
||||||
|
assert out.shape[2:] == (2, 2, 8)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWan22VAEEncoder:
|
||||||
|
"""Tests for Wan22VAEEncoder wrapper."""
|
||||||
|
|
||||||
|
def test_output_shape(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||||
|
|
||||||
|
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||||
|
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
|
||||||
|
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
z = enc(img)
|
||||||
|
mx.eval(z)
|
||||||
|
assert z.shape == (1, 1, 2, 2, 48)
|
||||||
|
|
||||||
|
def test_full_dim(self):
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||||
|
|
||||||
|
enc = Wan22VAEEncoder(z_dim=48, dim=160)
|
||||||
|
img = mx.random.normal((1, 1, 64, 64, 3))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
z = enc(img)
|
||||||
|
mx.eval(z)
|
||||||
|
# 64 / 16 = 4 (vae stride 16×)
|
||||||
|
assert z.shape == (1, 1, 4, 4, 48)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeLatents:
|
||||||
|
"""Tests for normalize/denormalize latent roundtrip."""
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
|
||||||
|
|
||||||
|
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||||
|
z_norm = normalize_latents(z)
|
||||||
|
z_back = denormalize_latents(z_norm)
|
||||||
|
mx.eval(z_back)
|
||||||
|
assert float(mx.abs(z - z_back).max()) < 1e-4
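# Illustrative sketch of the per-channel latent scaling checked by the roundtrip
# test, using the formula from the test comments (denormalize: z * std + mean).
# The NumPy helpers are hypothetical stand-ins, not the vae22 functions, and the
# mean/std values are made up for the example.

```python
import numpy as np


def denormalize(z, mean, std):
    """Map normalized latents back to the VAE's latent scale."""
    return z * std + mean


def normalize(z, mean, std):
    """Inverse of denormalize; mean/std broadcast over the last (channel) axis."""
    return (z - mean) / std


mean = np.array([1.0, 2.0, 3.0, 4.0])
std = np.array([0.5, 0.5, 0.5, 0.5])
z = np.ones((1, 1, 1, 1, 4))

out = denormalize(z, mean, std)
back = normalize(out, mean, std)
```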
|
||||||
|
|
||||||
|
|
||||||
|
class TestVAEEncoderTemporalOrder:
|
||||||
|
"""Tests that VAE encoder uses (False, True, True) temporal downsample order,
|
||||||
|
matching official Wan2.2 vae2_2.py."""
|
||||||
|
|
||||||
|
def test_encoder_temporal_downsample_pattern(self):
|
||||||
|
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||||
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
|
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||||
|
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
out = enc(x)
|
||||||
|
mx.eval(out)
|
||||||
|
assert out.shape[1] == 2
|
||||||
|
|
||||||
|
def test_wrapper_uses_correct_pattern(self):
|
||||||
|
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder, Resample
|
||||||
|
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||||
|
down_blocks = enc.encoder.downsamples
|
||||||
|
found_modes = []
|
||||||
|
for block in down_blocks:
|
||||||
|
for layer in block.downsamples:
|
||||||
|
if isinstance(layer, Resample):
|
||||||
|
found_modes.append(layer.mode)
|
||||||
|
# First spatial-only, then two with temporal
|
||||||
|
assert found_modes[0] == "downsample2d"
|
||||||
|
assert any("3d" in m for m in found_modes)
|
||||||
|
|
||||||
|
def test_single_frame_encoder(self):
|
||||||
|
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||||
|
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||||
|
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||||
|
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||||
|
mx.eval(enc.parameters())
|
||||||
|
z = enc(img)
|
||||||
|
mx.eval(z)
|
||||||
|
assert z.shape[1] == 1
|
||||||
|
assert z.shape[-1] == 48
|
||||||
|
|
||||||
|
def test_wrong_order_gives_different_result(self):
|
||||||
|
"""(True, True, False) and (False, True, True) both reduce T=5 to T=2; only the stage order differs."""
|
||||||
|
from mlx_video.models.wan.vae22 import Encoder3d
|
||||||
|
enc_correct = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||||
|
enc_wrong = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||||
|
mx.eval(enc_correct.parameters())
|
||||||
|
mx.eval(enc_wrong.parameters())
|
||||||
|
|
||||||
|
out_correct = enc_correct(x)
|
||||||
|
out_wrong = enc_wrong(x)
|
||||||
|
mx.eval(out_correct, out_wrong)
|
||||||
|
|
||||||
|
# Both give T=2; the patterns differ only in which stages downsample time (output values are not compared)
|
||||||
|
assert out_correct.shape[1] == 2
|
||||||
|
assert out_wrong.shape[1] == 2
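# Illustrative sketch of the temporal lengths quoted in the encoder tests above
# (T=5→5→3→2 for (False, True, True)). The rule (t + 1) // 2 per temporal
# downsample is inferred from those shape comments only and is an assumption,
# not taken from vae22.py. Both helper names are hypothetical.

```python
def downsampled_frames(t):
    """Assumed frame count after one 2x temporal downsample."""
    return (t + 1) // 2


def temporal_trace(t, pattern):
    """Frame count after each encoder stage for a temporal-downsample pattern."""
    trace = [t]
    for flag in pattern:
        t = downsampled_frames(t) if flag else t
        trace.append(t)
    return trace


correct = temporal_trace(5, (False, True, True))
wrong = temporal_trace(5, (True, True, False))
single = temporal_trace(1, (False, True, True))
```

# Both patterns end at T=2, which is why a shape check alone cannot tell them
# apart: they differ only in the stage at which time is downsampled.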
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
19
tests/wan_test_helpers.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Shared test helpers for Wan test modules."""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tiny_config():
|
||||||
|
"""Create a tiny WanModelConfig for testing."""
|
||||||
|
from mlx_video.models.wan.config import WanModelConfig
|
||||||
|
config = WanModelConfig()
|
||||||
|
# Override to tiny values
|
||||||
|
config.dim = 64
|
||||||
|
config.ffn_dim = 128
|
||||||
|
config.num_heads = 4
|
||||||
|
config.num_layers = 2
|
||||||
|
config.in_dim = 4
|
||||||
|
config.out_dim = 4
|
||||||
|
config.patch_size = (1, 2, 2)
|
||||||
|
config.freq_dim = 32
|
||||||
|
config.text_dim = 32
|
||||||
|
config.text_len = 8
|
||||||
|
return config
|
||||||