feat(wan): Add diagnostic scripts and porting guide
This commit is contained in:
911
docs/PORTING-GUIDE.md
Normal file
911
docs/PORTING-GUIDE.md
Normal file
@@ -0,0 +1,911 @@
|
||||
# Porting Diffusion Video Models to MLX: Lessons Learned
|
||||
|
||||
A practical guide distilled from porting Wan2.1/2.2 (1.3B–14B) and Helios 14B DiT
|
||||
video generation models from PyTorch to MLX on Apple Silicon. These lessons apply
|
||||
broadly to any diffusion-based video (or image) model port.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Debugging Methodology](#1-debugging-methodology)
|
||||
2. [Precision & Dtype Pitfalls](#2-precision--dtype-pitfalls)
|
||||
3. [MLX-Specific Gotchas](#3-mlx-specific-gotchas)
|
||||
4. [Autoregressive Chunk Boundaries](#4-autoregressive-chunk-boundaries)
|
||||
5. [VAE Decoder Artifacts](#5-vae-decoder-artifacts)
|
||||
6. [Scheduler & Timestep Issues](#6-scheduler--timestep-issues)
|
||||
7. [Weight Conversion](#7-weight-conversion)
|
||||
8. [Text Conditioning Failures](#8-text-conditioning-failures)
|
||||
9. [Position Encodings (RoPE)](#9-position-encodings-rope)
|
||||
10. [Multi-Stage / Pyramid Pipelines](#10-multi-stage--pyramid-pipelines)
|
||||
11. [Common Symptoms → Root Causes](#11-common-symptoms--root-causes)
|
||||
12. [Verification Checklist](#12-verification-checklist)
|
||||
13. [Diagnostic Tools](#13-diagnostic-tools)
|
||||
|
||||
---
|
||||
|
||||
## 1. Debugging Methodology
|
||||
|
||||
### Component isolation first
|
||||
|
||||
Never debug the full pipeline. Test each component in isolation:
|
||||
|
||||
1. **Text encoder** — Does it produce embeddings with reasonable statistics? (std > 0.01)
|
||||
2. **Scheduler** — Do sigma/timestep values match the reference exactly?
|
||||
3. **Transformer** — Does a single forward pass match the reference? (cosine similarity > 0.999)
|
||||
4. **VAE decoder** — Feed reference latents into your VAE. Does the output look correct?
|
||||
|
||||
If every component matches individually but the pipeline fails, the bug is in
|
||||
**orchestration** — how components are wired together.
|
||||
|
||||
### Statistical fingerprinting
|
||||
|
||||
Track per-step statistics through the diffusion loop:
|
||||
|
||||
```python
|
||||
# After each denoising step
|
||||
print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
|
||||
f"min={latent.min():.4f} max={latent.max():.4f}")
|
||||
```
|
||||
|
||||
**What to look for:**
|
||||
- **Progressive mean drift** (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
|
||||
- **Collapsing std** (std dropping toward 0) signals broken conditioning or wrong noise schedule
|
||||
- **Exploding values** signal wrong sigma scaling or scheduler formula
|
||||
|
||||
### Cross-framework numerical comparison
|
||||
|
||||
The most powerful debugging tool: save intermediate tensors from your MLX pipeline,
|
||||
feed them to the PyTorch reference, compare outputs.
|
||||
|
||||
```python
|
||||
# In MLX pipeline, save inputs before transformer call
|
||||
mx.save("debug_inputs.npz", {"latent": latent, "timestep": t, "text_emb": text_emb})
|
||||
|
||||
# In PyTorch script, load and compare
|
||||
inputs = np.load("debug_inputs.npz")
|
||||
mlx_out = np.load("debug_output.npz")["flow"]
|
||||
pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
|
||||
cos_sim = F.cosine_similarity(pt_out.flatten(), torch.from_numpy(mlx_out).flatten(), dim=0)
|
||||
# cos_sim > 0.999 = model is correct; bug is elsewhere
|
||||
# cos_sim < 0.99 = model has a bug; compare per-layer
|
||||
```
|
||||
|
||||
### Ablation testing
|
||||
|
||||
When a pipeline has multiple "fixes" or features, disable them one at a time:
|
||||
|
||||
- **Frozen history**: Fix history to the same value for all chunks → proves whether
|
||||
history propagation is the source of drift/zoom
|
||||
- **Single chunk**: Generate only 1 chunk → isolates per-chunk quality from
|
||||
multi-chunk interaction bugs
|
||||
- **Disable post-processing**: Remove cross-fade, blending, corrections → reveals
|
||||
what the raw model output looks like
|
||||
|
||||
### Use reference on same hardware
|
||||
|
||||
Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS
|
||||
produce different numerical results due to different float handling. Comparing your
|
||||
MLX output against a CUDA reference adds noise to the comparison.
|
||||
|
||||
```python
|
||||
# MPS may not support float64 — patch the reference:
|
||||
original_linspace = torch.linspace
|
||||
def patched_linspace(*args, **kwargs):
|
||||
kwargs.pop("dtype", None)
|
||||
return original_linspace(*args, dtype=torch.float32, **kwargs)
|
||||
torch.linspace = patched_linspace
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. Precision & Dtype Pitfalls
|
||||
|
||||
### The #1 source of subtle bugs
|
||||
|
||||
Precision issues caused the most insidious bugs in our port. They don't cause
|
||||
crashes — they cause progressive quality degradation that's hard to attribute.
|
||||
|
||||
### Residual connections MUST be float32
|
||||
|
||||
**Bug**: Progressive zoom/shrinking across autoregressive chunks.
|
||||
|
||||
**Root cause**: Residual additions (`x = x + attn_out`) in bfloat16. With 7-bit
|
||||
mantissa, high-frequency spatial detail is systematically truncated. Over 144
|
||||
residual ops × 6+ model calls per chunk, detail is progressively smoothed away.
|
||||
|
||||
**Fix**: Promote to float32 for the addition:
|
||||
```python
|
||||
# BAD — bfloat16 accumulation
|
||||
x = x + attn_out
|
||||
|
||||
# GOOD — match reference's .float() pattern
|
||||
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)
|
||||
```
|
||||
|
||||
**Rule**: If the reference uses `.float()` anywhere, copy that pattern exactly. It's
|
||||
there for a reason, even if a quick test seems to work without it.
|
||||
|
||||
### Scheduler computations need high precision
|
||||
|
||||
Diffusion schedulers involve:
|
||||
- `x0 = xt - sigma * flow` — catastrophic cancellation near sigma ≈ 1
|
||||
- `log(sigma)` and `exp()` — sensitive to small precision differences
|
||||
|
||||
Some references use float64 for these computations. MLX GPU doesn't support float64,
|
||||
so use float32 and accept small numerical differences, but **never** use bfloat16
|
||||
for scheduler math.
|
||||
|
||||
### Dtype propagation is invisible
|
||||
|
||||
Track dtype through your pipeline. A single bfloat16 intermediate can silently
|
||||
downcast everything downstream:
|
||||
|
||||
```python
|
||||
# This looks harmless but if model output is bfloat16:
|
||||
result = noise - sigma * model_output # result is bfloat16!
|
||||
|
||||
# Fix: explicit cast
|
||||
result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
|
||||
```
|
||||
|
||||
### Type promotion rules differ across frameworks
|
||||
|
||||
- PyTorch: bfloat16 + float32 → float32
|
||||
- MLX: bfloat16 + float32 → float32 (same, but verify)
|
||||
- NumPy: no bfloat16 support
|
||||
|
||||
Always check what your framework does and match the reference's implicit promotations.
|
||||
|
||||
### Float32 for VAE decoding
|
||||
|
||||
**Bug** (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
|
||||
|
||||
Official Wan2.2 runs VAE decode in `torch.float` (float32), but our converted weights
|
||||
were bfloat16. The VAE has many sequential layers where precision loss compounds.
|
||||
|
||||
**Fix**: Upcast VAE weights to float32 at load time. The VAE runs once per generation,
|
||||
so the performance impact is negligible compared to the transformer.
|
||||
|
||||
### Modulation/gate vectors need float32
|
||||
|
||||
**Bug** (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
|
||||
|
||||
The official Wan2.2 explicitly uses `torch.amp.autocast('cuda', dtype=torch.float32)`
|
||||
for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
|
||||
|
||||
**Fix**: Keep modulation in float32, cast to working dtype only when applying to the
|
||||
hidden state:
|
||||
```python
|
||||
# Modulation computed in float32
|
||||
e0 = self.modulation(time_emb) # float32
|
||||
scale, shift, gate = e0.split(3, axis=-1)
|
||||
|
||||
# Cast to bfloat16 only for the matmul with hidden state
|
||||
x = (x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype))
|
||||
```
|
||||
|
||||
### Map PyTorch autocast zones precisely
|
||||
|
||||
PyTorch models use nested `torch.amp.autocast` scopes to switch precision. Map these
|
||||
exactly:
|
||||
- **Outer scope** (`bfloat16`): attention QKV projections, FFN matmuls
|
||||
- **Inner scope** (`float32`): modulation, gates, norms, RoPE
|
||||
- **Residual stream**: float32 (the "backbone" between blocks)
|
||||
|
||||
```python
|
||||
# Wan2.2 dtype flow (matches official):
|
||||
# Modulation/gates: float32 (explicit)
|
||||
# QKV/FFN linear projections: bfloat16 (weight dtype)
|
||||
# RoPE: float32 (official uses float64, MLX lacks float64)
|
||||
# Attention Q/K: cast back to bfloat16 after RoPE
|
||||
# Residual stream: float32
|
||||
```
|
||||
|
||||
### Float32 promotion cascades kill performance
|
||||
|
||||
**Bug** (Wan2.2): ~2x slowdown from accidental float32 promotion.
|
||||
|
||||
A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations
|
||||
promotes the entire computation graph to float32. In Wan2.2:
|
||||
- Time embedding MLP output (float32) fed into transformer → all layers float32
|
||||
- RoPE frequencies (float32) applied to Q/K → all attention float32
|
||||
|
||||
**Fix**: Cast intermediate results to model dtype at promotion boundaries:
|
||||
```python
|
||||
# After time embedding MLP (float32), cast before feeding to transformer
|
||||
time_emb = time_mlp(t).astype(model_dtype)
|
||||
|
||||
# After RoPE (float32), cast Q/K back to attention dtype
|
||||
q = rope_apply(q, freqs).astype(v.dtype)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. MLX-Specific Gotchas
|
||||
|
||||
### Underscore-prefixed attributes are invisible
|
||||
|
||||
**Bug** (Wan2.2): 87 of 110 VAE weights silently dropped during loading.
|
||||
|
||||
MLX's `nn.Module.parameters()` and `nn.Module.load_weights()` **skip** attributes
|
||||
whose names start with underscore. If you name a layer `self._layer_0`, its weights
|
||||
will never be loaded or saved.
|
||||
|
||||
```python
|
||||
# BAD — weights silently ignored
|
||||
self._layer_0 = nn.Linear(...) # nn.Module skips _prefixed attrs
|
||||
|
||||
# GOOD
|
||||
self.layer_0 = nn.Linear(...)
|
||||
```
|
||||
|
||||
This is especially insidious because there's no error — the model loads, runs, and
|
||||
produces output. The output is just garbage because most weights are random.
|
||||
|
||||
### nn.Sequential indexing vs named children
|
||||
|
||||
PyTorch's `nn.Sequential` uses integer indices (`sequential.0.weight`), while MLX's
|
||||
module hierarchy uses named attributes. When mirroring a PyTorch module structure,
|
||||
you need explicit key sanitization:
|
||||
|
||||
```python
|
||||
def sanitize_key(key):
|
||||
# PyTorch: "decoder.middle.0.residual.1.weight"
|
||||
# MLX: "decoder.middle.layer_0.residual.layer_1.weight"
|
||||
key = re.sub(r'\.(\d+)', lambda m: f'.layer_{m.group(1)}', key)
|
||||
return key
|
||||
```
|
||||
|
||||
### Reshape axis ordering differs from PyTorch
|
||||
|
||||
**Bug** (Wan2.2): Green checkerboard pattern from VAE attention.
|
||||
|
||||
`[B,C,T,H,W]` cannot be directly reshaped to `[BT,C,H,W]` because in memory C
|
||||
comes before T. PyTorch's `reshape` works because it handles non-contiguous tensors.
|
||||
MLX requires explicit transpose first:
|
||||
|
||||
```python
|
||||
# BAD — mixes channels with time
|
||||
x = x.reshape(B*T, C, H, W) # Corrupts spatial layout
|
||||
|
||||
# GOOD — make B,T adjacent first
|
||||
x = x.transpose(0, 2, 1, 3, 4) # [B,T,C,H,W]
|
||||
x = x.reshape(B*T, C, H, W) # Now correct
|
||||
```
|
||||
|
||||
### Patchify channel ordering
|
||||
|
||||
**Bug** (Wan2.2): Solid green video output from wrong patchify order.
|
||||
|
||||
When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
|
||||
in the reshape must match the Conv3d weight layout. Conv3d expects `[C, pt, ph, pw]`
|
||||
(channels slowest), but a naive reshape produces `[pt, ph, pw, C]` (channels fastest):
|
||||
|
||||
```python
|
||||
# BAD — channel scrambling
|
||||
patches = x.reshape(B, F', H', W', pt, ph, pw, C)
|
||||
|
||||
# GOOD — match Conv3d weight layout
|
||||
patches = x.reshape(B, F', pt, H', ph, W', pw, C)
|
||||
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, F', H', W', C, pt, ph, pw]
|
||||
```
|
||||
|
||||
Verify numerically: the fixed version should match Conv3d output to ~1e-6.
|
||||
|
||||
### mx.zeros / padding inherits dtype
|
||||
|
||||
Use dtype-aware `mx.zeros` for padding and concatenation to avoid promotion:
|
||||
|
||||
```python
|
||||
# BAD — default float32 padding promotes bfloat16 input
|
||||
pad = mx.zeros((B, pad_len, C)) # float32!
|
||||
x = mx.concatenate([pad, x], axis=1) # x promoted to float32
|
||||
|
||||
# GOOD — match input dtype
|
||||
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
|
||||
x = mx.concatenate([pad, x], axis=1) # stays bfloat16
|
||||
```
|
||||
|
||||
### Use mx.fast kernels
|
||||
|
||||
Replace manual implementations with fused MLX kernels where possible:
|
||||
|
||||
```python
|
||||
# Manual RMS norm → mx.fast.rms_norm
|
||||
# Manual LayerNorm → mx.fast.layer_norm
|
||||
# Manual attention → mx.fast.scaled_dot_product_attention
|
||||
```
|
||||
|
||||
These are faster and handle precision internally.
|
||||
|
||||
---
|
||||
|
||||
## 4. Autoregressive Chunk Boundaries
|
||||
|
||||
For models that generate long videos by autoregressively extending chunks (Helios,
|
||||
CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.
|
||||
|
||||
### Don't add post-processing the reference doesn't have
|
||||
|
||||
**Bug**: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.
|
||||
|
||||
The reference pipeline used **no cross-fade at all**. The first frame of each new
|
||||
chunk is intentionally a sharp reconstruction conditioned on history. Blending it with
|
||||
the previous chunk's tail (which has different content) creates blur.
|
||||
|
||||
**Rule**: Before adding smoothing/blending, verify the reference doesn't do it.
|
||||
Reference simplicity is usually correct.
|
||||
|
||||
### First-frame artifacts are common
|
||||
|
||||
The first pixel frame of each non-first chunk is typically a distorted reconstruction
|
||||
of the conditioning frame. In many models, this is expected behavior:
|
||||
|
||||
- **Fix**: Drop the first frame from each chunk
|
||||
- **Verify frame math**: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
|
||||
|
||||
### History conditioning errors compound
|
||||
|
||||
Small errors in how history is prepared, sliced, patchified, or position-encoded
|
||||
will compound across chunks. The error is invisible in chunk 1, small in chunk 2,
|
||||
and catastrophic by chunk 5.
|
||||
|
||||
**Debug strategy**: Generate with frozen history (same history for every chunk).
|
||||
If the artifact disappears, the bug is in history handling.
|
||||
|
||||
---
|
||||
|
||||
## 5. VAE Decoder Artifacts
|
||||
|
||||
### Causal temporal convolutions cause boundary warmup
|
||||
|
||||
Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding
|
||||
each chunk independently, the first few frames lack temporal context (only zero
|
||||
padding), causing:
|
||||
|
||||
- **~7% contrast drop** in first frames of each chunk
|
||||
- **Spatial brightness redistribution** (face darkens, background brightens)
|
||||
|
||||
This is inherent to the architecture. The reference has the same effect but at
|
||||
lower magnitude.
|
||||
|
||||
### Post-processing to fix VAE warmup
|
||||
|
||||
Two-stage correction applied to first N frames of each non-first chunk:
|
||||
|
||||
```python
|
||||
# Stage 1: Spatially-varying brightness correction
|
||||
# Downsample reference (previous chunk's last frame) and current frame
|
||||
ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
|
||||
cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
|
||||
diff_small = ref_small - cur_small
|
||||
diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
corrected = cur_frame + ramp * diff_full # ramp: 1.0 → 0.0 over N frames
|
||||
|
||||
# Stage 2: Per-channel contrast matching
|
||||
for c in range(3):
|
||||
ref_std = np.std(ref_frame[:,:,c])
|
||||
cur_std = np.std(corrected[:,:,c])
|
||||
scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
|
||||
corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean
|
||||
```
|
||||
|
||||
### VAE overlap decode does NOT work
|
||||
|
||||
**Attempted**: Prepend previous chunk's last latent frames to give the decoder
|
||||
temporal context.
|
||||
|
||||
**Result**: Made things **worse** (22% contrast drop vs 7%). The causal convolutions
|
||||
see conflicting content from different chunks and create larger artifacts than
|
||||
zero-padding.
|
||||
|
||||
**Lesson**: Overlap only works when tiles contain the same content from the same
|
||||
denoising process (e.g., spatial tiling). It fails for temporal chunks with
|
||||
different content.
|
||||
|
||||
### Per-chunk VAE decoding is correct
|
||||
|
||||
Decode each chunk's latents independently, not concatenated. Concatenating all chunks
|
||||
and decoding together lets boundary discontinuities propagate through temporal
|
||||
convolutions, creating worse artifacts.
|
||||
|
||||
### First-frame quality: causal padding strategies
|
||||
|
||||
Multiple approaches were tried for the first-frame quality issue in Wan VAE:
|
||||
|
||||
| Approach | Result |
|
||||
|----------|--------|
|
||||
| Zero padding (default) | First ~4 frames degraded, but matches training |
|
||||
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
|
||||
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
|
||||
| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias |
|
||||
|
||||
**Lesson**: Don't assume "replicate padding is better than zero padding." The model
|
||||
was trained with zero padding; changing it shifts the gain. Instead, prepend warmup
|
||||
frames and trim them after decoding.
|
||||
|
||||
### RMS_norm vs L2 normalization
|
||||
|
||||
**Bug** (Wan2.2): Garbled output from incorrect normalization.
|
||||
|
||||
A PyTorch class named `RMS_norm` actually uses `F.normalize` (L2 norm: `x / ||x||_2`),
|
||||
not RMS normalization (`x / sqrt(mean(x²))`). The difference is a factor of `sqrt(C)`,
|
||||
causing values to explode through the decoder.
|
||||
|
||||
**Lesson**: Don't trust class names — read the actual implementation.
|
||||
|
||||
### Temporal frame count: causal boundary effects
|
||||
|
||||
**Bug** (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.
|
||||
|
||||
PyTorch reference processes frames one-by-one with caching, skipping temporal conv for
|
||||
the first chunk. All-at-once decoding produces extra frames from zero-padded causal
|
||||
context.
|
||||
|
||||
**Fix**: Use `first_chunk=True` flag to trim causal boundary frames, matching the
|
||||
reference's chunked behavior.
|
||||
|
||||
### Chunked VAE encoding for I2V
|
||||
|
||||
**Bug** (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.
|
||||
|
||||
Non-chunked encoding with causal zero-padding produces incorrect latents because
|
||||
temporal features don't propagate correctly without caching. The reference uses chunked
|
||||
encoding (1+4+4+... frames) with persistent temporal cache.
|
||||
|
||||
**Fix**: Implement chunked encoding with `feat_cache` propagation through CausalConv3d,
|
||||
ResidualBlock, and Resample layers.
|
||||
|
||||
---
|
||||
|
||||
## 6. Scheduler & Timestep Issues
|
||||
|
||||
### Copy formulas exactly
|
||||
|
||||
Even small differences in scheduler formulas compound over many steps:
|
||||
|
||||
```python
|
||||
# Dynamic time shifting — reference uses specific formula
|
||||
mu = 0.5 + shift * 0.5 # NOT shift * 0.6 or any other constant
|
||||
|
||||
# Euler step
|
||||
x_next = x + (sigma_next - sigma) * flow # order matters: next - current
|
||||
```
|
||||
|
||||
### Verify sigma schedules numerically
|
||||
|
||||
Print and compare sigma values at each step:
|
||||
|
||||
```python
|
||||
# Reference
|
||||
sigmas_ref = [1.0, 0.99375, 0.9875, ...]
|
||||
|
||||
# Your implementation
|
||||
sigmas = scheduler.get_sigmas(steps)
|
||||
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
|
||||
assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}"
|
||||
```
|
||||
|
||||
### Timestep embedding precision
|
||||
|
||||
Integer vs float timesteps matter. Some models expect `timestep=999` (int), others
|
||||
expect `timestep=0.999` (float). Wrong type can silently produce wrong embeddings
|
||||
with reasonable-looking but incorrect statistics.
|
||||
|
||||
### Boundary conditions: ±inf at sigma endpoints
|
||||
|
||||
**Bug** (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.
|
||||
|
||||
The `lambda(sigma)` function must return `-inf` at `sigma=1.0` (pure noise) and `+inf`
|
||||
at `sigma=0.0` (clean signal). Our implementation returned `0.0`, causing massive x0
|
||||
overscaling on the first denoising step.
|
||||
|
||||
PyTorch naturally computes `torch.log(0) = -inf`, and `math.expm1(-inf) = -1.0`
|
||||
handles the formulas correctly. Reproduce this behavior explicitly:
|
||||
|
||||
```python
|
||||
def _lambda(self, sigma):
|
||||
if sigma >= 1.0:
|
||||
return float('-inf')
|
||||
if sigma <= 0.0:
|
||||
return float('inf')
|
||||
return -math.log(sigma / (1 - sigma))
|
||||
```
|
||||
|
||||
### UniPC corrector coefficients
|
||||
|
||||
**Bug** (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.
|
||||
|
||||
The UniPC corrector must compute `rhos_c` via `linalg.solve` for order ≥ 2. Hardcoded
|
||||
`0.5` was 7× too large for the history weight (actual: ~0.08), causing massive
|
||||
overweighting of history corrections.
|
||||
|
||||
---
|
||||
|
||||
## 7. Weight Conversion
|
||||
|
||||
### Always verify statistically
|
||||
|
||||
After converting weights from PyTorch to MLX format:
|
||||
|
||||
```python
|
||||
for name in mlx_weights:
|
||||
pt = pytorch_weights[map_name(name)]
|
||||
mx_val = np.array(mlx_weights[name])
|
||||
pt_val = pt.numpy()
|
||||
cos_sim = np.dot(mx_val.flat, pt_val.flat) / (
|
||||
np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
|
||||
)
|
||||
if cos_sim < 0.9999:
|
||||
print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
|
||||
```
|
||||
|
||||
### Conv3d → Linear reshaping
|
||||
|
||||
When converting 3D convolutions to linear layers (common for MLX which prefers
|
||||
linear ops), the flattening order must match:
|
||||
|
||||
```python
|
||||
# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
|
||||
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
|
||||
# The reshape order MUST match how the input is patchified
|
||||
```
|
||||
|
||||
### Sanitization functions
|
||||
|
||||
Write explicit weight sanitization that maps reference key names to your key names.
|
||||
Don't rely on automatic matching — key naming conventions differ between frameworks.
|
||||
|
||||
### Module structure must mirror reference for direct loading
|
||||
|
||||
**Bug** (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch `nn.Sequential`
|
||||
structure. ResidualBlock needed `None` gaps at specific indices to match the original
|
||||
`nn.Sequential(RMSNorm, SiLU, Conv3d, ...)` indexing.
|
||||
|
||||
When possible, structure your modules to accept reference weights directly without
|
||||
sanitization. This eliminates an entire class of bugs.
|
||||
|
||||
### Save VAE weights in float32
|
||||
|
||||
Even if the model uses bfloat16 for the transformer, save VAE weights in float32.
|
||||
bfloat16 → float32 roundtrip loses precision that cannot be recovered by load-time
|
||||
upcast.
|
||||
|
||||
### Temporal downsample/upsample order
|
||||
|
||||
**Bug** (Wan2.2): `temporal_downsample=[True, True, False]` but reference uses
|
||||
`[False, True, True]`. Stage 0 created a `time_conv` with random weights (no matching
|
||||
file key), and Stage 2 missed its `time_conv` (weights silently dropped).
|
||||
|
||||
Always verify boolean flags for each stage by inspecting the actual weight file keys.
|
||||
|
||||
### Silent weight drops are the worst bugs
|
||||
|
||||
When `load_weights()` with `strict=False` silently skips keys that don't match, you
|
||||
get a model with random weights for those layers. This produces output that looks
|
||||
"almost right" but is subtly wrong. Always log which keys were loaded vs skipped:
|
||||
|
||||
```python
|
||||
loaded_keys = set()
|
||||
for key, value in weights:
|
||||
if key in model_params:
|
||||
loaded_keys.add(key)
|
||||
# Check for missing
|
||||
expected = set(model_params.keys())
|
||||
missing = expected - loaded_keys
|
||||
if missing:
|
||||
print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Text Conditioning Failures
|
||||
|
||||
### Symptom: model predicts noise back to itself
|
||||
|
||||
If the model output correlates > 0.8 with its input noise, text conditioning is
|
||||
likely broken. The model has learned nothing from the prompt and is just returning
|
||||
its input.
|
||||
|
||||
### Check embedding statistics
|
||||
|
||||
```python
|
||||
text_emb = text_encoder(prompt)
|
||||
print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}")
|
||||
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
|
||||
# std > 10.0 → embeddings are exploding → wrong normalization
|
||||
```
|
||||
|
||||
### Verify with ablation
|
||||
|
||||
```python
|
||||
# Generate with real text
|
||||
output_text = denoise(latent, text_emb=real_embeddings)
|
||||
# Generate with zeros
|
||||
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
|
||||
# Compare
|
||||
text_influence = np.mean(np.abs(output_text - output_zero))
|
||||
print(f"Text influence: {text_influence:.4f}") # Should be > 0 (typically 30-60% of output)
|
||||
```
|
||||
|
||||
### Text preprocessing must match exactly
|
||||
|
||||
**Bug** (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.
|
||||
|
||||
The official Wan2.2 tokenizer applies `ftfy.fix_text` + `html.unescape` + whitespace
|
||||
normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C)
|
||||
tokenize differently from ASCII commas (U+002C), causing **27 different token IDs**
|
||||
in the negative prompt. This made CFG's unconditional prediction wrong.
|
||||
|
||||
**Fix**: Apply the same text cleaning pipeline as the reference:
|
||||
```python
|
||||
import ftfy
|
||||
import html
|
||||
import re
|
||||
|
||||
def clean_text(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
```
|
||||
|
||||
### T5 encoder precision
|
||||
|
||||
**Bug** (Wan2.2): Quality degradation from bfloat16 T5 attention.
|
||||
|
||||
T5 uses **no scaling** in attention (no `1/sqrt(d)` factor), so attention logits can
|
||||
be very large. bfloat16 softmax loses significant precision across 24 encoder layers.
|
||||
|
||||
**Fix**: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per
|
||||
generation, so the performance impact is negligible.
|
||||
|
||||
### Dual-model text embeddings
|
||||
|
||||
**Bug** (Wan2.2 I2V-14B): Low/high noise models have different `text_embedding` weights
|
||||
(~42% relative difference). Using one model's embeddings for both caused incorrect
|
||||
text conditioning for the high-noise model that handles critical early denoising steps.
|
||||
|
||||
**Fix**: Compute separate text embeddings for each model in dual-model setups.
|
||||
|
||||
---
|
||||
|
||||
## 9. Position Encodings (RoPE)
|
||||
|
||||
### Multi-scale consistency
|
||||
|
||||
In pyramid/multi-resolution models, RoPE must be computed consistently across scales.
|
||||
If the model operates at 1/4 resolution in an early stage, the position grid must
|
||||
reflect the actual spatial dimensions, not the final target dimensions.
|
||||
|
||||
### History vs current chunk
|
||||
|
||||
When conditioning on history from a previous chunk, the position encoding for
|
||||
history frames must match what the model saw during training. Mismatches between
|
||||
history and current-chunk position encodings can cause subtle spatial distortions
|
||||
that compound across chunks.
|
||||
|
||||
### Factorized RoPE
|
||||
|
||||
3D video models often use factorized RoPE (separate temporal, height, width
|
||||
frequencies). Verify each axis independently:
|
||||
|
||||
```python
|
||||
# Compare temporal frequencies
|
||||
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
|
||||
# Compare spatial frequencies
|
||||
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
|
||||
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)
|
||||
```
|
||||
|
||||
### Per-axis frequency construction
|
||||
|
||||
**Bug** (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.
|
||||
|
||||
The reference uses three separate `rope_params()` calls with different dimension
|
||||
normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency
|
||||
range. Consolidating into a single `rope_params(head_dim)` call and splitting gave
|
||||
height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).
|
||||
|
||||
**Fix** (and subsequent revert): This bug was introduced as a "fix" for a previous
|
||||
RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects.
|
||||
Always verify with actual generation, not just numerical comparison of frequencies.
|
||||
|
||||
**Lesson**: Read the reference's frequency construction very carefully. Don't
|
||||
"simplify" three separate calls into one unless you verify the frequency distribution
|
||||
matches exactly.
|
||||
|
||||
---
|
||||
|
||||
## 10. Multi-Stage / Pyramid Pipelines
|
||||
|
||||
### Each stage is a potential failure point
|
||||
|
||||
Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the
|
||||
number of things that can go wrong:
|
||||
|
||||
- Downsampling method (bilinear vs area) must match reference
|
||||
- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
|
||||
- Alpha/beta noise mixing coefficients are stage-dependent
|
||||
- Frame indices and history resolution change per stage
|
||||
|
||||
### Test single-stage first
|
||||
|
||||
If the model works at full resolution for a single stage but fails in the pyramid,
|
||||
the bug is in stage orchestration — typically in how latents are passed between
|
||||
stages or how position encodings adapt to different resolutions.
|
||||
|
||||
### Integration bugs are the hardest
|
||||
|
||||
We verified every Helios component matched the reference individually, but the
|
||||
pyramid still produced uniform color. The bug was in dtype handling during stage
|
||||
transitions. Integration bugs only appear when components interact.
|
||||
|
||||
---
|
||||
|
||||
## 11. Common Symptoms → Root Causes
|
||||
|
||||
| Symptom | Likely Root Causes |
|
||||
|---------|-------------------|
|
||||
| **Pure noise output** | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
|
||||
| **Uniform color** | Model predicting noise back; text embeddings collapsed; wrong timestep format |
|
||||
| **Progressive zoom/shrink** | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
|
||||
| **Brightness jumps at boundaries** | VAE causal warmup; cross-fade blending misaligned content |
|
||||
| **Color drift across chunks** | Dtype in scheduler step; history normalization missing |
|
||||
| **Blur at boundaries** | Cross-fade enabled; latent blending; wrong VAE decode order |
|
||||
| **Grid/checker patterns** | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
|
||||
| **Green/magenta tint** | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
|
||||
| **Mean drift across steps** | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
|
||||
| **Garbled/scrambled output** | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
|
||||
| **Greenish-yellow constant** | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
|
||||
| **~2x slower than expected** | Float32 promotion cascade from single mistyped intermediate |
|
||||
| **Extra output frames** | Causal padding producing extra temporal frames; missing `first_chunk` trim |
|
||||
| **Grey/artifact output** | RoPE frequency construction wrong (per-axis vs single-call) |
|
||||
| **Patchy-blurry with CFG** | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
|
||||
| **I2V temporal mismatch** | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |
|
||||
|
||||
---
|
||||
|
||||
## 12. Verification Checklist
|
||||
|
||||
Use this checklist when porting a new diffusion video model:
|
||||
|
||||
### Model
|
||||
- [ ] Weight conversion: all keys mapped, cosine similarity > 0.9999
|
||||
- [ ] No silent weight drops (log loaded vs expected keys)
|
||||
- [ ] Single forward pass matches reference (cos_sim > 0.999)
|
||||
- [ ] Residual connections use float32 accumulation
|
||||
- [ ] Attention computation matches reference precision
|
||||
- [ ] Modulation/gate vectors in float32 (if reference uses autocast)
|
||||
- [ ] No underscore-prefixed module attributes (MLX ignores them)
|
||||
|
||||
### Scheduler
|
||||
- [ ] Sigma values match reference at every step (diff < 1e-6)
|
||||
- [ ] Timestep format correct (int vs float, scale factor)
|
||||
- [ ] Dynamic shifting formula copied exactly
|
||||
- [ ] Step function returns correct dtype (float32)
|
||||
- [ ] Boundary conditions: lambda(-inf) at sigma=1, lambda(+inf) at sigma=0
|
||||
- [ ] Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
|
||||
|
||||
### Text Encoder
|
||||
- [ ] Embedding statistics reasonable (0.01 < std < 10)
|
||||
- [ ] Text influence > 0 (ablation test)
|
||||
- [ ] Tokenization matches (special tokens, padding, max length)
|
||||
- [ ] Text preprocessing matches (ftfy, html unescape, whitespace normalization)
|
||||
- [ ] T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
|
||||
- [ ] Separate embeddings for dual-model setups (if applicable)
|
||||
|
||||
### VAE
|
||||
- [ ] Denormalization constants match training pipeline
|
||||
- [ ] Per-chunk decoding (not concatenated)
|
||||
- [ ] Temporal frame count correct (account for causal padding)
|
||||
- [ ] Weight keys mapped correctly (encoder vs decoder)
|
||||
- [ ] Weights stored/loaded in float32 (not bfloat16)
|
||||
- [ ] Temporal downsample/upsample order matches reference
|
||||
- [ ] RMS_norm vs L2_norm: check actual implementation, not class name
|
||||
- [ ] Chunked encoding for I2V (if applicable)
|
||||
- [ ] Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
|
||||
|
||||
### Pipeline Orchestration
|
||||
- [ ] Position encodings consistent across stages/chunks
|
||||
- [ ] History slicing and conditioning correct
|
||||
- [ ] Noise generation matches (distribution, correlation structure)
|
||||
- [ ] Multi-chunk output visually consistent (no progressive degradation)
|
||||
- [ ] Dimension auto-alignment (divisible by patch_size × vae_stride)
|
||||
- [ ] Dtype-aware padding (mx.zeros with explicit dtype)
|
||||
|
||||
### Output
|
||||
- [ ] Frame count matches expected (account for warmup/trim)
|
||||
- [ ] FPS correct
|
||||
- [ ] Color range [0, 255] uint8 for video
|
||||
- [ ] No first-frame duplication artifacts
|
||||
- [ ] Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
|
||||
|
||||
### Performance
|
||||
- [ ] No float32 promotion cascades (check with profiler)
|
||||
- [ ] Using mx.fast kernels (rms_norm, layer_norm, sdpa)
|
||||
- [ ] Time embedding computed once per sample (not per position)
|
||||
- [ ] Memory cleanup (delete temporaries before mx.eval)
|
||||
|
||||
---
|
||||
|
||||
## 13. Diagnostic Tools
|
||||
|
||||
### General video diagnostics (`scripts/video/`)
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `compare_videos.py` | PSNR, SSIM, temporal coherence, color fidelity between two videos |
|
||||
| `video_quality.py` | Sharpness, stability, defect detection, chunk boundary analysis |
|
||||
|
||||
```bash
|
||||
# Quick quality check
|
||||
python scripts/video/video_quality.py output.mp4 --chunk-size 32
|
||||
|
||||
# Compare against reference
|
||||
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
|
||||
```
|
||||
|
||||
### Model-specific diagnostics (`scripts/helios/`)
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `analyze_boundaries.py` | Detailed boundary quality metrics for Helios |
|
||||
| `run_reference.py` | Run PyTorch reference on MPS |
|
||||
| `compare_pipelines.py` | Compare scheduler/pipeline mechanics |
|
||||
| `compare_models.py` | Cross-framework model output comparison |
|
||||
|
||||
### Inline debugging pattern
|
||||
|
||||
Add temporary debug output to the diffusion loop:
|
||||
|
||||
```python
|
||||
for i, sigma in enumerate(sigmas):
|
||||
flow = model(latent, sigma, text_emb)
|
||||
latent = scheduler.step(latent, flow, sigma, sigma_next)
|
||||
|
||||
# Debug: track statistics
|
||||
print(f"[step {i}] sigma={sigma:.4f} "
|
||||
f"latent: mean={latent.mean():.6f} std={latent.std():.6f} "
|
||||
f"flow: mean={flow.mean():.6f} std={flow.std():.6f}")
|
||||
|
||||
# Debug: save for cross-framework comparison
|
||||
if os.environ.get("DEBUG"):
|
||||
mx.save(f"/tmp/debug_step_{i}.npz", {
|
||||
"latent": latent, "flow": flow, "sigma": mx.array(sigma)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Takeaways
|
||||
|
||||
1. **Precision is the #1 bug source** — bfloat16 residuals, scheduler math, type
|
||||
promotion, modulation vectors. Copy the reference's `.float()` and `autocast` zones.
|
||||
|
||||
2. **Don't add what the reference doesn't have** — cross-fade, overlap decode,
|
||||
temporal blending. If the reference works without it, you probably have a bug
|
||||
elsewhere.
|
||||
|
||||
3. **Silent failures are the hardest bugs** — underscore-prefixed weights, `strict=False`
|
||||
weight loading, wrong normalization class names. Always verify weight load counts
|
||||
and output statistics.
|
||||
|
||||
4. **Component isolation → integration testing** — verify each part matches, then
|
||||
debug their interaction.
|
||||
|
||||
5. **Statistical comparison beats visual inspection** — mean drift, contrast ratios,
|
||||
and cosine similarity catch bugs before they're visible.
|
||||
|
||||
6. **Autoregressive errors compound** — a 1% error per chunk becomes 10% by chunk 10.
|
||||
Fix precision first, add corrections second.
|
||||
|
||||
7. **MLX has unique pitfalls** — underscore attribute names, reshape axis ordering,
|
||||
dtype-unaware padding, and float32 promotion cascades. Know your framework.
|
||||
|
||||
8. **Text preprocessing matters** — Unicode normalization, fullwidth chars, HTML entities.
|
||||
A single mismatched comma can break CFG guidance.
|
||||
|
||||
9. **VAE is deceptively complex** — causal padding, temporal frame counts, chunked vs
|
||||
batch processing, norm implementations. Budget significant debugging time for VAE.
|
||||
Reference in New Issue
Block a user