feat(wan): Add diagnostic scripts and porting guide

This commit is contained in:
Daniel
2026-03-06 20:46:43 +01:00
parent 9bdda9f22e
commit 967218b7c1
3 changed files with 1565 additions and 0 deletions

911
docs/PORTING-GUIDE.md Normal file
View File

@@ -0,0 +1,911 @@
# Porting Diffusion Video Models to MLX: Lessons Learned
A practical guide distilled from porting Wan2.1/2.2 (1.3B14B) and Helios 14B DiT
video generation models from PyTorch to MLX on Apple Silicon. These lessons apply
broadly to any diffusion-based video (or image) model port.
---
## Table of Contents
1. [Debugging Methodology](#1-debugging-methodology)
2. [Precision & Dtype Pitfalls](#2-precision--dtype-pitfalls)
3. [MLX-Specific Gotchas](#3-mlx-specific-gotchas)
4. [Autoregressive Chunk Boundaries](#4-autoregressive-chunk-boundaries)
5. [VAE Decoder Artifacts](#5-vae-decoder-artifacts)
6. [Scheduler & Timestep Issues](#6-scheduler--timestep-issues)
7. [Weight Conversion](#7-weight-conversion)
8. [Text Conditioning Failures](#8-text-conditioning-failures)
9. [Position Encodings (RoPE)](#9-position-encodings-rope)
10. [Multi-Stage / Pyramid Pipelines](#10-multi-stage--pyramid-pipelines)
11. [Common Symptoms → Root Causes](#11-common-symptoms--root-causes)
12. [Verification Checklist](#12-verification-checklist)
13. [Diagnostic Tools](#13-diagnostic-tools)
---
## 1. Debugging Methodology
### Component isolation first
Never debug the full pipeline. Test each component in isolation:
1. **Text encoder** — Does it produce embeddings with reasonable statistics? (std > 0.01)
2. **Scheduler** — Do sigma/timestep values match the reference exactly?
3. **Transformer** — Does a single forward pass match the reference? (cosine similarity > 0.999)
4. **VAE decoder** — Feed reference latents into your VAE. Does the output look correct?
If every component matches individually but the pipeline fails, the bug is in
**orchestration** — how components are wired together.
### Statistical fingerprinting
Track per-step statistics through the diffusion loop:
```python
# After each denoising step
print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
f"min={latent.min():.4f} max={latent.max():.4f}")
```
**What to look for:**
- **Progressive mean drift** (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
- **Collapsing std** (std dropping toward 0) signals broken conditioning or wrong noise schedule
- **Exploding values** signal wrong sigma scaling or scheduler formula
### Cross-framework numerical comparison
The most powerful debugging tool: save intermediate tensors from your MLX pipeline,
feed them to the PyTorch reference, compare outputs.
```python
# In MLX pipeline, save inputs before transformer call
mx.save("debug_inputs.npz", {"latent": latent, "timestep": t, "text_emb": text_emb})
# In PyTorch script, load and compare
inputs = np.load("debug_inputs.npz")
mlx_out = np.load("debug_output.npz")["flow"]
pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
cos_sim = F.cosine_similarity(pt_out.flatten(), torch.from_numpy(mlx_out).flatten(), dim=0)
# cos_sim > 0.999 = model is correct; bug is elsewhere
# cos_sim < 0.99 = model has a bug; compare per-layer
```
### Ablation testing
When a pipeline has multiple "fixes" or features, disable them one at a time:
- **Frozen history**: Fix history to the same value for all chunks → proves whether
history propagation is the source of drift/zoom
- **Single chunk**: Generate only 1 chunk → isolates per-chunk quality from
multi-chunk interaction bugs
- **Disable post-processing**: Remove cross-fade, blending, corrections → reveals
what the raw model output looks like
### Use reference on same hardware
Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS
produce different numerical results due to different float handling. Comparing your
MLX output against a CUDA reference adds noise to the comparison.
```python
# MPS may not support float64 — patch the reference:
original_linspace = torch.linspace
def patched_linspace(*args, **kwargs):
kwargs.pop("dtype", None)
return original_linspace(*args, dtype=torch.float32, **kwargs)
torch.linspace = patched_linspace
```
---
## 2. Precision & Dtype Pitfalls
### The #1 source of subtle bugs
Precision issues caused the most insidious bugs in our port. They don't cause
crashes — they cause progressive quality degradation that's hard to attribute.
### Residual connections MUST be float32
**Bug**: Progressive zoom/shrinking across autoregressive chunks.
**Root cause**: Residual additions (`x = x + attn_out`) in bfloat16. With 7-bit
mantissa, high-frequency spatial detail is systematically truncated. Over 144
residual ops × 6+ model calls per chunk, detail is progressively smoothed away.
**Fix**: Promote to float32 for the addition:
```python
# BAD — bfloat16 accumulation
x = x + attn_out
# GOOD — match reference's .float() pattern
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)
```
**Rule**: If the reference uses `.float()` anywhere, copy that pattern exactly. It's
there for a reason, even if a quick test seems to work without it.
### Scheduler computations need high precision
Diffusion schedulers involve:
- `x0 = xt - sigma * flow` — catastrophic cancellation near sigma ≈ 1
- `log(sigma)` and `exp()` — sensitive to small precision differences
Some references use float64 for these computations. MLX GPU doesn't support float64,
so use float32 and accept small numerical differences, but **never** use bfloat16
for scheduler math.
### Dtype propagation is invisible
Track dtype through your pipeline. A single bfloat16 intermediate can silently
downcast everything downstream:
```python
# This looks harmless but if model output is bfloat16:
result = noise - sigma * model_output # result is bfloat16!
# Fix: explicit cast
result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
```
### Type promotion rules differ across frameworks
- PyTorch: bfloat16 + float32 → float32
- MLX: bfloat16 + float32 → float32 (same, but verify)
- NumPy: no bfloat16 support
Always check what your framework does and match the reference's implicit promotations.
### Float32 for VAE decoding
**Bug** (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
Official Wan2.2 runs VAE decode in `torch.float` (float32), but our converted weights
were bfloat16. The VAE has many sequential layers where precision loss compounds.
**Fix**: Upcast VAE weights to float32 at load time. The VAE runs once per generation,
so the performance impact is negligible compared to the transformer.
### Modulation/gate vectors need float32
**Bug** (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
The official Wan2.2 explicitly uses `torch.amp.autocast('cuda', dtype=torch.float32)`
for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
**Fix**: Keep modulation in float32, cast to working dtype only when applying to the
hidden state:
```python
# Modulation computed in float32
e0 = self.modulation(time_emb) # float32
scale, shift, gate = e0.split(3, axis=-1)
# Cast to bfloat16 only for the matmul with hidden state
x = (x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype))
```
### Map PyTorch autocast zones precisely
PyTorch models use nested `torch.amp.autocast` scopes to switch precision. Map these
exactly:
- **Outer scope** (`bfloat16`): attention QKV projections, FFN matmuls
- **Inner scope** (`float32`): modulation, gates, norms, RoPE
- **Residual stream**: float32 (the "backbone" between blocks)
```python
# Wan2.2 dtype flow (matches official):
# Modulation/gates: float32 (explicit)
# QKV/FFN linear projections: bfloat16 (weight dtype)
# RoPE: float32 (official uses float64, MLX lacks float64)
# Attention Q/K: cast back to bfloat16 after RoPE
# Residual stream: float32
```
### Float32 promotion cascades kill performance
**Bug** (Wan2.2): ~2x slowdown from accidental float32 promotion.
A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations
promotes the entire computation graph to float32. In Wan2.2:
- Time embedding MLP output (float32) fed into transformer → all layers float32
- RoPE frequencies (float32) applied to Q/K → all attention float32
**Fix**: Cast intermediate results to model dtype at promotion boundaries:
```python
# After time embedding MLP (float32), cast before feeding to transformer
time_emb = time_mlp(t).astype(model_dtype)
# After RoPE (float32), cast Q/K back to attention dtype
q = rope_apply(q, freqs).astype(v.dtype)
```
---
## 3. MLX-Specific Gotchas
### Underscore-prefixed attributes are invisible
**Bug** (Wan2.2): 87 of 110 VAE weights silently dropped during loading.
MLX's `nn.Module.parameters()` and `nn.Module.load_weights()` **skip** attributes
whose names start with underscore. If you name a layer `self._layer_0`, its weights
will never be loaded or saved.
```python
# BAD — weights silently ignored
self._layer_0 = nn.Linear(...) # nn.Module skips _prefixed attrs
# GOOD
self.layer_0 = nn.Linear(...)
```
This is especially insidious because there's no error — the model loads, runs, and
produces output. The output is just garbage because most weights are random.
### nn.Sequential indexing vs named children
PyTorch's `nn.Sequential` uses integer indices (`sequential.0.weight`), while MLX's
module hierarchy uses named attributes. When mirroring a PyTorch module structure,
you need explicit key sanitization:
```python
def sanitize_key(key):
# PyTorch: "decoder.middle.0.residual.1.weight"
# MLX: "decoder.middle.layer_0.residual.layer_1.weight"
key = re.sub(r'\.(\d+)', lambda m: f'.layer_{m.group(1)}', key)
return key
```
### Reshape axis ordering differs from PyTorch
**Bug** (Wan2.2): Green checkerboard pattern from VAE attention.
`[B,C,T,H,W]` cannot be directly reshaped to `[BT,C,H,W]` because in memory C
comes before T. PyTorch's `reshape` works because it handles non-contiguous tensors.
MLX requires explicit transpose first:
```python
# BAD — mixes channels with time
x = x.reshape(B*T, C, H, W) # Corrupts spatial layout
# GOOD — make B,T adjacent first
x = x.transpose(0, 2, 1, 3, 4) # [B,T,C,H,W]
x = x.reshape(B*T, C, H, W) # Now correct
```
### Patchify channel ordering
**Bug** (Wan2.2): Solid green video output from wrong patchify order.
When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
in the reshape must match the Conv3d weight layout. Conv3d expects `[C, pt, ph, pw]`
(channels slowest), but a naive reshape produces `[pt, ph, pw, C]` (channels fastest):
```python
# BAD — channel scrambling
patches = x.reshape(B, F', H', W', pt, ph, pw, C)
# GOOD — match Conv3d weight layout
patches = x.reshape(B, F', pt, H', ph, W', pw, C)
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, F', H', W', C, pt, ph, pw]
```
Verify numerically: the fixed version should match Conv3d output to ~1e-6.
### mx.zeros / padding inherits dtype
Use dtype-aware `mx.zeros` for padding and concatenation to avoid promotion:
```python
# BAD — default float32 padding promotes bfloat16 input
pad = mx.zeros((B, pad_len, C)) # float32!
x = mx.concatenate([pad, x], axis=1) # x promoted to float32
# GOOD — match input dtype
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
x = mx.concatenate([pad, x], axis=1) # stays bfloat16
```
### Use mx.fast kernels
Replace manual implementations with fused MLX kernels where possible:
```python
# Manual RMS norm → mx.fast.rms_norm
# Manual LayerNorm → mx.fast.layer_norm
# Manual attention → mx.fast.scaled_dot_product_attention
```
These are faster and handle precision internally.
---
## 4. Autoregressive Chunk Boundaries
For models that generate long videos by autoregressively extending chunks (Helios,
CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.
### Don't add post-processing the reference doesn't have
**Bug**: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.
The reference pipeline used **no cross-fade at all**. The first frame of each new
chunk is intentionally a sharp reconstruction conditioned on history. Blending it with
the previous chunk's tail (which has different content) creates blur.
**Rule**: Before adding smoothing/blending, verify the reference doesn't do it.
Reference simplicity is usually correct.
### First-frame artifacts are common
The first pixel frame of each non-first chunk is typically a distorted reconstruction
of the conditioning frame. In many models, this is expected behavior:
- **Fix**: Drop the first frame from each chunk
- **Verify frame math**: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
### History conditioning errors compound
Small errors in how history is prepared, sliced, patchified, or position-encoded
will compound across chunks. The error is invisible in chunk 1, small in chunk 2,
and catastrophic by chunk 5.
**Debug strategy**: Generate with frozen history (same history for every chunk).
If the artifact disappears, the bug is in history handling.
---
## 5. VAE Decoder Artifacts
### Causal temporal convolutions cause boundary warmup
Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding
each chunk independently, the first few frames lack temporal context (only zero
padding), causing:
- **~7% contrast drop** in first frames of each chunk
- **Spatial brightness redistribution** (face darkens, background brightens)
This is inherent to the architecture. The reference has the same effect but at
lower magnitude.
### Post-processing to fix VAE warmup
Two-stage correction applied to first N frames of each non-first chunk:
```python
# Stage 1: Spatially-varying brightness correction
# Downsample reference (previous chunk's last frame) and current frame
ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
diff_small = ref_small - cur_small
diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
corrected = cur_frame + ramp * diff_full # ramp: 1.0 → 0.0 over N frames
# Stage 2: Per-channel contrast matching
for c in range(3):
ref_std = np.std(ref_frame[:,:,c])
cur_std = np.std(corrected[:,:,c])
scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean
```
### VAE overlap decode does NOT work
**Attempted**: Prepend previous chunk's last latent frames to give the decoder
temporal context.
**Result**: Made things **worse** (22% contrast drop vs 7%). The causal convolutions
see conflicting content from different chunks and create larger artifacts than
zero-padding.
**Lesson**: Overlap only works when tiles contain the same content from the same
denoising process (e.g., spatial tiling). It fails for temporal chunks with
different content.
### Per-chunk VAE decoding is correct
Decode each chunk's latents independently, not concatenated. Concatenating all chunks
and decoding together lets boundary discontinuities propagate through temporal
convolutions, creating worse artifacts.
### First-frame quality: causal padding strategies
Multiple approaches were tried for the first-frame quality issue in Wan VAE:
| Approach | Result |
|----------|--------|
| Zero padding (default) | First ~4 frames degraded, but matches training |
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias |
**Lesson**: Don't assume "replicate padding is better than zero padding." The model
was trained with zero padding; changing it shifts the gain. Instead, prepend warmup
frames and trim them after decoding.
### RMS_norm vs L2 normalization
**Bug** (Wan2.2): Garbled output from incorrect normalization.
A PyTorch class named `RMS_norm` actually uses `F.normalize` (L2 norm: `x / ||x||_2`),
not RMS normalization (`x / sqrt(mean(x²))`). The difference is a factor of `sqrt(C)`,
causing values to explode through the decoder.
**Lesson**: Don't trust class names — read the actual implementation.
### Temporal frame count: causal boundary effects
**Bug** (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.
PyTorch reference processes frames one-by-one with caching, skipping temporal conv for
the first chunk. All-at-once decoding produces extra frames from zero-padded causal
context.
**Fix**: Use `first_chunk=True` flag to trim causal boundary frames, matching the
reference's chunked behavior.
### Chunked VAE encoding for I2V
**Bug** (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.
Non-chunked encoding with causal zero-padding produces incorrect latents because
temporal features don't propagate correctly without caching. The reference uses chunked
encoding (1+4+4+... frames) with persistent temporal cache.
**Fix**: Implement chunked encoding with `feat_cache` propagation through CausalConv3d,
ResidualBlock, and Resample layers.
---
## 6. Scheduler & Timestep Issues
### Copy formulas exactly
Even small differences in scheduler formulas compound over many steps:
```python
# Dynamic time shifting — reference uses specific formula
mu = 0.5 + shift * 0.5 # NOT shift * 0.6 or any other constant
# Euler step
x_next = x + (sigma_next - sigma) * flow # order matters: next - current
```
### Verify sigma schedules numerically
Print and compare sigma values at each step:
```python
# Reference
sigmas_ref = [1.0, 0.99375, 0.9875, ...]
# Your implementation
sigmas = scheduler.get_sigmas(steps)
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}"
```
### Timestep embedding precision
Integer vs float timesteps matter. Some models expect `timestep=999` (int), others
expect `timestep=0.999` (float). Wrong type can silently produce wrong embeddings
with reasonable-looking but incorrect statistics.
### Boundary conditions: ±inf at sigma endpoints
**Bug** (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.
The `lambda(sigma)` function must return `-inf` at `sigma=1.0` (pure noise) and `+inf`
at `sigma=0.0` (clean signal). Our implementation returned `0.0`, causing massive x0
overscaling on the first denoising step.
PyTorch naturally computes `torch.log(0) = -inf`, and `math.expm1(-inf) = -1.0`
handles the formulas correctly. Reproduce this behavior explicitly:
```python
def _lambda(self, sigma):
if sigma >= 1.0:
return float('-inf')
if sigma <= 0.0:
return float('inf')
return -math.log(sigma / (1 - sigma))
```
### UniPC corrector coefficients
**Bug** (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.
The UniPC corrector must compute `rhos_c` via `linalg.solve` for order ≥ 2. Hardcoded
`0.5` was 7× too large for the history weight (actual: ~0.08), causing massive
overweighting of history corrections.
---
## 7. Weight Conversion
### Always verify statistically
After converting weights from PyTorch to MLX format:
```python
for name in mlx_weights:
pt = pytorch_weights[map_name(name)]
mx_val = np.array(mlx_weights[name])
pt_val = pt.numpy()
cos_sim = np.dot(mx_val.flat, pt_val.flat) / (
np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
)
if cos_sim < 0.9999:
print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
```
### Conv3d → Linear reshaping
When converting 3D convolutions to linear layers (common for MLX which prefers
linear ops), the flattening order must match:
```python
# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
# The reshape order MUST match how the input is patchified
```
### Sanitization functions
Write explicit weight sanitization that maps reference key names to your key names.
Don't rely on automatic matching — key naming conventions differ between frameworks.
### Module structure must mirror reference for direct loading
**Bug** (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch `nn.Sequential`
structure. ResidualBlock needed `None` gaps at specific indices to match the original
`nn.Sequential(RMSNorm, SiLU, Conv3d, ...)` indexing.
When possible, structure your modules to accept reference weights directly without
sanitization. This eliminates an entire class of bugs.
### Save VAE weights in float32
Even if the model uses bfloat16 for the transformer, save VAE weights in float32.
bfloat16 → float32 roundtrip loses precision that cannot be recovered by load-time
upcast.
### Temporal downsample/upsample order
**Bug** (Wan2.2): `temporal_downsample=[True, True, False]` but reference uses
`[False, True, True]`. Stage 0 created a `time_conv` with random weights (no matching
file key), and Stage 2 missed its `time_conv` (weights silently dropped).
Always verify boolean flags for each stage by inspecting the actual weight file keys.
### Silent weight drops are the worst bugs
When `load_weights()` with `strict=False` silently skips keys that don't match, you
get a model with random weights for those layers. This produces output that looks
"almost right" but is subtly wrong. Always log which keys were loaded vs skipped:
```python
loaded_keys = set()
for key, value in weights:
if key in model_params:
loaded_keys.add(key)
# Check for missing
expected = set(model_params.keys())
missing = expected - loaded_keys
if missing:
print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...")
```
---
## 8. Text Conditioning Failures
### Symptom: model predicts noise back to itself
If the model output correlates > 0.8 with its input noise, text conditioning is
likely broken. The model has learned nothing from the prompt and is just returning
its input.
### Check embedding statistics
```python
text_emb = text_encoder(prompt)
print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}")
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
# std > 10.0 → embeddings are exploding → wrong normalization
```
### Verify with ablation
```python
# Generate with real text
output_text = denoise(latent, text_emb=real_embeddings)
# Generate with zeros
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
# Compare
text_influence = np.mean(np.abs(output_text - output_zero))
print(f"Text influence: {text_influence:.4f}") # Should be > 0 (typically 30-60% of output)
```
### Text preprocessing must match exactly
**Bug** (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.
The official Wan2.2 tokenizer applies `ftfy.fix_text` + `html.unescape` + whitespace
normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C)
tokenize differently from ASCII commas (U+002C), causing **27 different token IDs**
in the negative prompt. This made CFG's unconditional prediction wrong.
**Fix**: Apply the same text cleaning pipeline as the reference:
```python
import ftfy
import html
import re
def clean_text(text):
text = ftfy.fix_text(text)
text = html.unescape(text)
text = re.sub(r'\s+', ' ', text).strip()
return text
```
### T5 encoder precision
**Bug** (Wan2.2): Quality degradation from bfloat16 T5 attention.
T5 uses **no scaling** in attention (no `1/sqrt(d)` factor), so attention logits can
be very large. bfloat16 softmax loses significant precision across 24 encoder layers.
**Fix**: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per
generation, so the performance impact is negligible.
### Dual-model text embeddings
**Bug** (Wan2.2 I2V-14B): Low/high noise models have different `text_embedding` weights
(~42% relative difference). Using one model's embeddings for both caused incorrect
text conditioning for the high-noise model that handles critical early denoising steps.
**Fix**: Compute separate text embeddings for each model in dual-model setups.
---
## 9. Position Encodings (RoPE)
### Multi-scale consistency
In pyramid/multi-resolution models, RoPE must be computed consistently across scales.
If the model operates at 1/4 resolution in an early stage, the position grid must
reflect the actual spatial dimensions, not the final target dimensions.
### History vs current chunk
When conditioning on history from a previous chunk, the position encoding for
history frames must match what the model saw during training. Mismatches between
history and current-chunk position encodings can cause subtle spatial distortions
that compound across chunks.
### Factorized RoPE
3D video models often use factorized RoPE (separate temporal, height, width
frequencies). Verify each axis independently:
```python
# Compare temporal frequencies
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
# Compare spatial frequencies
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)
```
### Per-axis frequency construction
**Bug** (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.
The reference uses three separate `rope_params()` calls with different dimension
normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency
range. Consolidating into a single `rope_params(head_dim)` call and splitting gave
height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).
**Fix** (and subsequent revert): This bug was introduced as a "fix" for a previous
RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects.
Always verify with actual generation, not just numerical comparison of frequencies.
**Lesson**: Read the reference's frequency construction very carefully. Don't
"simplify" three separate calls into one unless you verify the frequency distribution
matches exactly.
---
## 10. Multi-Stage / Pyramid Pipelines
### Each stage is a potential failure point
Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the
number of things that can go wrong:
- Downsampling method (bilinear vs area) must match reference
- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
- Alpha/beta noise mixing coefficients are stage-dependent
- Frame indices and history resolution change per stage
### Test single-stage first
If the model works at full resolution for a single stage but fails in the pyramid,
the bug is in stage orchestration — typically in how latents are passed between
stages or how position encodings adapt to different resolutions.
### Integration bugs are the hardest
We verified every Helios component matched the reference individually, but the
pyramid still produced uniform color. The bug was in dtype handling during stage
transitions. Integration bugs only appear when components interact.
---
## 11. Common Symptoms → Root Causes
| Symptom | Likely Root Causes |
|---------|-------------------|
| **Pure noise output** | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
| **Uniform color** | Model predicting noise back; text embeddings collapsed; wrong timestep format |
| **Progressive zoom/shrink** | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
| **Brightness jumps at boundaries** | VAE causal warmup; cross-fade blending misaligned content |
| **Color drift across chunks** | Dtype in scheduler step; history normalization missing |
| **Blur at boundaries** | Cross-fade enabled; latent blending; wrong VAE decode order |
| **Grid/checker patterns** | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
| **Green/magenta tint** | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
| **Mean drift across steps** | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
| **Garbled/scrambled output** | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
| **Greenish-yellow constant** | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
| **~2x slower than expected** | Float32 promotion cascade from single mistyped intermediate |
| **Extra output frames** | Causal padding producing extra temporal frames; missing `first_chunk` trim |
| **Grey/artifact output** | RoPE frequency construction wrong (per-axis vs single-call) |
| **Patchy-blurry with CFG** | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
| **I2V temporal mismatch** | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |
---
## 12. Verification Checklist
Use this checklist when porting a new diffusion video model:
### Model
- [ ] Weight conversion: all keys mapped, cosine similarity > 0.9999
- [ ] No silent weight drops (log loaded vs expected keys)
- [ ] Single forward pass matches reference (cos_sim > 0.999)
- [ ] Residual connections use float32 accumulation
- [ ] Attention computation matches reference precision
- [ ] Modulation/gate vectors in float32 (if reference uses autocast)
- [ ] No underscore-prefixed module attributes (MLX ignores them)
### Scheduler
- [ ] Sigma values match reference at every step (diff < 1e-6)
- [ ] Timestep format correct (int vs float, scale factor)
- [ ] Dynamic shifting formula copied exactly
- [ ] Step function returns correct dtype (float32)
- [ ] Boundary conditions: lambda(-inf) at sigma=1, lambda(+inf) at sigma=0
- [ ] Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
### Text Encoder
- [ ] Embedding statistics reasonable (0.01 < std < 10)
- [ ] Text influence > 0 (ablation test)
- [ ] Tokenization matches (special tokens, padding, max length)
- [ ] Text preprocessing matches (ftfy, html unescape, whitespace normalization)
- [ ] T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
- [ ] Separate embeddings for dual-model setups (if applicable)
### VAE
- [ ] Denormalization constants match training pipeline
- [ ] Per-chunk decoding (not concatenated)
- [ ] Temporal frame count correct (account for causal padding)
- [ ] Weight keys mapped correctly (encoder vs decoder)
- [ ] Weights stored/loaded in float32 (not bfloat16)
- [ ] Temporal downsample/upsample order matches reference
- [ ] RMS_norm vs L2_norm: check actual implementation, not class name
- [ ] Chunked encoding for I2V (if applicable)
- [ ] Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
### Pipeline Orchestration
- [ ] Position encodings consistent across stages/chunks
- [ ] History slicing and conditioning correct
- [ ] Noise generation matches (distribution, correlation structure)
- [ ] Multi-chunk output visually consistent (no progressive degradation)
- [ ] Dimension auto-alignment (divisible by patch_size × vae_stride)
- [ ] Dtype-aware padding (mx.zeros with explicit dtype)
### Output
- [ ] Frame count matches expected (account for warmup/trim)
- [ ] FPS correct
- [ ] Color range [0, 255] uint8 for video
- [ ] No first-frame duplication artifacts
- [ ] Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
### Performance
- [ ] No float32 promotion cascades (check with profiler)
- [ ] Using mx.fast kernels (rms_norm, layer_norm, sdpa)
- [ ] Time embedding computed once per sample (not per position)
- [ ] Memory cleanup (delete temporaries before mx.eval)
---
## 13. Diagnostic Tools
### General video diagnostics (`scripts/video/`)
| Script | Purpose |
|--------|---------|
| `compare_videos.py` | PSNR, SSIM, temporal coherence, color fidelity between two videos |
| `video_quality.py` | Sharpness, stability, defect detection, chunk boundary analysis |
```bash
# Quick quality check
python scripts/video/video_quality.py output.mp4 --chunk-size 32
# Compare against reference
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
```
### Model-specific diagnostics (`scripts/helios/`)
| Script | Purpose |
|--------|---------|
| `analyze_boundaries.py` | Detailed boundary quality metrics for Helios |
| `run_reference.py` | Run PyTorch reference on MPS |
| `compare_pipelines.py` | Compare scheduler/pipeline mechanics |
| `compare_models.py` | Cross-framework model output comparison |
### Inline debugging pattern
Add temporary debug output to the diffusion loop:
```python
for i, sigma in enumerate(sigmas):
flow = model(latent, sigma, text_emb)
latent = scheduler.step(latent, flow, sigma, sigma_next)
# Debug: track statistics
print(f"[step {i}] sigma={sigma:.4f} "
f"latent: mean={latent.mean():.6f} std={latent.std():.6f} "
f"flow: mean={flow.mean():.6f} std={flow.std():.6f}")
# Debug: save for cross-framework comparison
if os.environ.get("DEBUG"):
mx.save(f"/tmp/debug_step_{i}.npz", {
"latent": latent, "flow": flow, "sigma": mx.array(sigma)
})
```
---
## Key Takeaways
1. **Precision is the #1 bug source** — bfloat16 residuals, scheduler math, type
promotion, modulation vectors. Copy the reference's `.float()` and `autocast` zones.
2. **Don't add what the reference doesn't have** — cross-fade, overlap decode,
temporal blending. If the reference works without it, you probably have a bug
elsewhere.
3. **Silent failures are the hardest bugs** — underscore-prefixed weights, `strict=False`
weight loading, wrong normalization class names. Always verify weight load counts
and output statistics.
4. **Component isolation → integration testing** — verify each part matches, then
debug their interaction.
5. **Statistical comparison beats visual inspection** — mean drift, contrast ratios,
and cosine similarity catch bugs before they're visible.
6. **Autoregressive errors compound** — a 1% error per chunk becomes 10% by chunk 10.
Fix precision first, add corrections second.
7. **MLX has unique pitfalls** — underscore attribute names, reshape axis ordering,
dtype-unaware padding, and float32 promotion cascades. Know your framework.
8. **Text preprocessing matters** — Unicode normalization, fullwidth chars, HTML entities.
A single mismatched comma can break CFG guidance.
9. **VAE is deceptively complex** — causal padding, temporal frame counts, chunked vs
batch processing, norm implementations. Budget significant debugging time for VAE.