Porting Diffusion Video Models to MLX: Lessons Learned
A practical guide distilled from porting Wan2.1/2.2 (1.3B–14B) and Helios 14B DiT video generation models from PyTorch to MLX on Apple Silicon. These lessons apply broadly to any diffusion-based video (or image) model port.
Table of Contents
- Debugging Methodology
- Precision & Dtype Pitfalls
- MLX-Specific Gotchas
- Autoregressive Chunk Boundaries
- VAE Decoder Artifacts
- Scheduler & Timestep Issues
- Weight Conversion
- Text Conditioning Failures
- Position Encodings (RoPE)
- Multi-Stage / Pyramid Pipelines
- Common Symptoms → Root Causes
- Verification Checklist
- Diagnostic Tools
1. Debugging Methodology
Component isolation first
Never debug the full pipeline. Test each component in isolation:
- Text encoder — Does it produce embeddings with reasonable statistics? (std > 0.01)
- Scheduler — Do sigma/timestep values match the reference exactly?
- Transformer — Does a single forward pass match the reference? (cosine similarity > 0.999)
- VAE decoder — Feed reference latents into your VAE. Does the output look correct?
If every component matches individually but the pipeline fails, the bug is in orchestration — how components are wired together.
Statistical fingerprinting
Track per-step statistics through the diffusion loop:
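```python
# After each denoising step (.item() turns 0-d arrays into Python floats for formatting)
print(f"step {i}: mean={latent.mean().item():.6f} std={latent.std().item():.6f} "
      f"min={latent.min().item():.4f} max={latent.max().item():.4f}")
```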
What to look for:
- Progressive mean drift (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
- Collapsing std (std dropping toward 0) signals broken conditioning or wrong noise schedule
- Exploding values signal wrong sigma scaling or scheduler formula
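The three patterns above can be checked automatically instead of by eye. The following is a minimal NumPy sketch. It assumes you collect one stats dict per step, as in the print above. `step_stats`, `flag_trajectory`, and all three thresholds are hypothetical and illustrative, and should be tuned per model:

```python
import numpy as np

def step_stats(latent):
    """Fingerprint one step; accepts anything np.asarray can convert (cast MLX bfloat16 first)."""
    a = np.asarray(latent, dtype=np.float64)
    return {"mean": a.mean(), "std": a.std(), "min": a.min(), "max": a.max()}

def flag_trajectory(stats, drift_tol=0.05, collapse_ratio=0.1, explode_abs=1e3):
    """Return warnings for a list of per-step stats dicts. Thresholds are illustrative guesses."""
    means = [s["mean"] for s in stats]
    stds = [s["std"] for s in stats]
    warnings = []
    # Progressive drift: |mean| never shrinks and ends far from where it started
    growing = all(abs(b) >= abs(a) for a, b in zip(means, means[1:]))
    if growing and abs(means[-1] - means[0]) > drift_tol:
        warnings.append("progressive mean drift")
    # Collapsing std: final spread is a small fraction of the initial spread
    if stds[-1] < collapse_ratio * stds[0]:
        warnings.append("std collapsing toward 0")
    # Exploding values: any step exceeds an absolute bound
    if any(max(abs(s["min"]), abs(s["max"])) > explode_abs for s in stats):
        warnings.append("exploding values")
    return warnings

# The drift example from the list above
history = [{"mean": m, "std": 1.0, "min": -3.0, "max": 3.0} for m in (-0.002, -0.040, -0.123)]
print(flag_trajectory(history))
```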
Cross-framework numerical comparison
The most powerful debugging tool: save intermediate tensors from your MLX pipeline, feed them to the PyTorch reference, compare outputs.
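```python
# In the MLX pipeline: save inputs before the transformer call, and its output after.
# Cast to float32 first: NumPy cannot read bfloat16 arrays.
mx.savez("debug_inputs.npz", latent=latent.astype(mx.float32),
         timestep=mx.array(t).astype(mx.float32), text_emb=text_emb.astype(mx.float32))
mx.savez("debug_output.npz", flow=flow.astype(mx.float32))
# In the PyTorch script: load, rerun the reference on the same inputs, compare
import numpy as np
import torch
import torch.nn.functional as F
inputs = np.load("debug_inputs.npz")
mlx_out = np.load("debug_output.npz")["flow"]
with torch.no_grad():
    pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
cos_sim = F.cosine_similarity(pt_out.float().flatten(),
                              torch.from_numpy(mlx_out).flatten(), dim=0)
# cos_sim > 0.999 = model is correct; bug is elsewhere
# cos_sim < 0.99 = model has a bug; compare per-layer
```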
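Comparing per layer means finding the first layer where the two frameworks disagree, since everything after it inherits the error. The following is a NumPy sketch. It assumes you have dumped per-layer activations from both sides into dicts keyed by layer name in forward order. `first_divergent_layer` is a hypothetical helper, and 0.999 is the threshold quoted above:

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity of two arrays after flattening, in float64."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

def first_divergent_layer(mlx_acts, ref_acts, threshold=0.999):
    """Return (layer, cos_sim) for the first layer below threshold, or None.

    Both dicts map layer name -> activation; iteration follows the reference
    dict's insertion order, which should be forward-pass order.
    """
    for name, ref in ref_acts.items():
        sim = cosine(mlx_acts[name], ref)
        if sim < threshold:
            return name, sim
    return None

rng = np.random.default_rng(0)
ref_acts = {f"blocks.{i}": rng.standard_normal(64) for i in range(4)}
mlx_acts = dict(ref_acts)
mlx_acts["blocks.2"] = ref_acts["blocks.2"] + rng.standard_normal(64)  # simulated bug
print(first_divergent_layer(mlx_acts, ref_acts))
```

Once the first bad layer is known, compare that layer's inputs. If the inputs still match, the bug is inside the layer.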
Ablation testing
When a pipeline has multiple "fixes" or features, disable them one at a time:
- Frozen history: Fix history to the same value for all chunks → proves whether history propagation is the source of drift/zoom
- Single chunk: Generate only 1 chunk → isolates per-chunk quality from multi-chunk interaction bugs
- Disable post-processing: Remove cross-fade, blending, corrections → reveals what the raw model output looks like
Use reference on same hardware
Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS use different kernel implementations and accumulation orders, so they produce slightly different numerical results. Comparing your MLX output against a CUDA reference mixes that cross-device noise into the comparison.
```python
# MPS does not support float64, so patch the reference to force float32:
import torch

original_linspace = torch.linspace

def patched_linspace(*args, **kwargs):
    kwargs.pop("dtype", None)  # drop any float64 request
    return original_linspace(*args, dtype=torch.float32, **kwargs)

torch.linspace = patched_linspace
```
2. Precision & Dtype Pitfalls
The #1 source of subtle bugs
Precision issues caused the most insidious bugs in our port. They don't cause crashes — they cause progressive quality degradation that's hard to attribute.
Residual connections MUST be float32
Bug: Progressive zoom/shrinking across autoregressive chunks.
Root cause: Residual additions (x = x + attn_out) in bfloat16. With 7-bit
mantissa, high-frequency spatial detail is systematically truncated. Over 144
residual ops × 6+ model calls per chunk, detail is progressively smoothed away.
Fix: Promote to float32 for the addition:
```python
# BAD: bfloat16 accumulation
x = x + attn_out
# GOOD: match the reference's .float() pattern
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)
```
Rule: If the reference uses .float() anywhere, copy that pattern exactly. It's
there for a reason, even if a quick test seems to work without it.
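The mechanism can be reproduced with a toy example. NumPy has no bfloat16, so this sketch emulates it. bfloat16 is the top 16 bits of a float32, so the sketch rounds those bits to nearest even. NaN and inf handling are omitted. In bfloat16, 256 is followed by 258: with an 8-bit significand, the spacing between representable values near 256 is 2. So every +0.5 update to a residual of 256 rounds away:

```python
import numpy as np

def to_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32."""
    f = np.array(x, dtype=np.float32, ndmin=1)
    bits = f.view(np.uint32).astype(np.uint64)           # widen so the rounding add can't overflow
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return bits.astype(np.uint32).view(np.float32).reshape(np.shape(x))

def accumulate(x0, update, n, bf16):
    """Add the same small update n times, optionally rounding to bfloat16 after each add."""
    x = np.float32(x0)
    for _ in range(n):
        x = np.float32(x + np.float32(update))
        if bf16:
            x = np.float32(to_bf16(x))
    return float(x)

print(accumulate(256.0, 0.5, 100, bf16=False))  # float32 keeps every update
print(accumulate(256.0, 0.5, 100, bf16=True))   # bfloat16 drops every update
```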
Scheduler computations need high precision
Diffusion schedulers involve:
- x0 = xt - sigma * flow: catastrophic cancellation near sigma ≈ 1
- log(sigma) and exp(): sensitive to small precision differences
Some references use float64 for these computations. MLX GPU doesn't support float64, so use float32 and accept small numerical differences, but never use bfloat16 for scheduler math.
Dtype propagation is invisible
Track dtype through your pipeline. A single bfloat16 intermediate can silently downcast everything downstream:
```python
# This looks harmless, but if model output is bfloat16:
result = noise - sigma * model_output  # result is bfloat16!
# Fix: explicit cast
result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
```
Type promotion rules differ across frameworks
- PyTorch: bfloat16 + float32 → float32
- MLX: bfloat16 + float32 → float32 (same, but verify)
- NumPy: no bfloat16 support
Always check what your framework does and match the reference's implicit promotions.
Float32 for VAE decoding
Bug (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
Official Wan2.2 runs VAE decode in torch.float (float32), but our converted weights
were bfloat16. The VAE has many sequential layers where precision loss compounds.
Fix: Upcast VAE weights to float32 at load time. The VAE runs once per generation, so the performance impact is negligible compared to the transformer.
Modulation/gate vectors need float32
Bug (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
The official Wan2.2 explicitly uses torch.amp.autocast('cuda', dtype=torch.float32)
for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
Fix: Keep modulation in float32, cast to working dtype only when applying to the hidden state:
```python
# Modulation computed in float32
e0 = self.modulation(time_emb)  # float32
scale, shift, gate = mx.split(e0, 3, axis=-1)
# Cast to the hidden-state dtype (bfloat16) only when modulating the hidden state
x = x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype)
```
Map PyTorch autocast zones precisely
PyTorch models use nested torch.amp.autocast scopes to switch precision. Map these
exactly:
- Outer scope (bfloat16): attention QKV projections, FFN matmuls
- Inner scope (float32): modulation, gates, norms, RoPE
- Residual stream: float32 (the "backbone" between blocks)
```python
# Wan2.2 dtype flow (matches official):
# Modulation/gates: float32 (explicit)
# QKV/FFN linear projections: bfloat16 (weight dtype)
# RoPE: float32 (official uses float64, MLX lacks float64)
# Attention Q/K: cast back to bfloat16 after RoPE
# Residual stream: float32
```
Float32 promotion cascades kill performance
Bug (Wan2.2): ~2x slowdown from accidental float32 promotion.
A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations promotes the entire computation graph to float32. In Wan2.2:
- Time embedding MLP output (float32) fed into transformer → all layers float32
- RoPE frequencies (float32) applied to Q/K → all attention float32
Fix: Cast intermediate results to model dtype at promotion boundaries:
```python
# After time embedding MLP (float32), cast before feeding to transformer
time_emb = time_mlp(t).astype(model_dtype)
# After RoPE (float32), cast Q/K back to attention dtype
q = rope_apply(q, freqs).astype(v.dtype)
```
3. MLX-Specific Gotchas
Underscore-prefixed attributes are invisible
Bug (Wan2.2): 87 of 110 VAE weights silently dropped during loading.
MLX's nn.Module.parameters() and nn.Module.load_weights() skip attributes
whose names start with underscore. If you name a layer self._layer_0, its weights
will never be loaded or saved.
```python
# BAD: weights silently ignored
self._layer_0 = nn.Linear(...)  # nn.Module skips _prefixed attrs
# GOOD
self.layer_0 = nn.Linear(...)
```
This is especially insidious because there's no error — the model loads, runs, and produces output. The output is just garbage because most weights are random.
nn.Sequential indexing vs named children
PyTorch's nn.Sequential uses integer indices (sequential.0.weight). If your MLX
modules mirror that structure with named attributes instead (layer_0, layer_1, ...),
you need explicit key sanitization:
```python
import re

def sanitize_key(key):
    # PyTorch: "decoder.middle.0.residual.1.weight"
    # MLX:     "decoder.middle.layer_0.residual.layer_1.weight"
    # Only rewrite path components that are purely numeric
    return re.sub(r'\.(\d+)(?=\.|$)', lambda m: f'.layer_{m.group(1)}', key)
```
Reshape axis ordering differs from PyTorch
Bug (Wan2.2): Green checkerboard pattern from VAE attention.
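[B,C,T,H,W] cannot be directly reshaped to [BT,C,H,W]. Reshape reinterprets the buffer
in row-major order, and in memory C comes before T, so a bare reshape interleaves
channels with time. This holds in PyTorch too; a working reference permutes the axes
before flattening (for example with an einops rearrange), and that step is easy to miss
when porting. MLX needs the same explicit transpose first: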
```python
# BAD: mixes channels with time
x = x.reshape(B*T, C, H, W)  # corrupts spatial layout
# GOOD: make B and T adjacent first
x = x.transpose(0, 2, 1, 3, 4)  # [B,T,C,H,W]
x = x.reshape(B*T, C, H, W)  # now correct
```
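A small NumPy check makes the difference concrete. NumPy's reshape uses the same row-major reinterpretation, so it reproduces the bug. The shapes are arbitrary:

```python
import numpy as np

B, C, T, H, W = 1, 3, 4, 2, 2
x = np.arange(B * C * T * H * W).reshape(B, C, T, H, W)

bad = x.reshape(B * T, C, H, W)                            # reinterprets memory: mixes C and T
good = x.transpose(0, 2, 1, 3, 4).reshape(B * T, C, H, W)  # [B,T,C,H,W] first, then flatten

# Frame t of the GOOD result holds exactly the C channels of time step t
for t in range(T):
    assert np.array_equal(good[t], x[0, :, t])
```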
Patchify channel ordering
Bug (Wan2.2): Solid green video output from wrong patchify order.
When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
in the reshape must match the Conv3d weight layout. Conv3d expects [C, pt, ph, pw]
(channels slowest), but a naive reshape produces [pt, ph, pw, C] (channels fastest):
```python
# x: [B, F, H, W, C] (channels-last); Fp, Hp, Wp = F // pt, H // ph, W // pw
# BAD: channel scrambling
patches = x.reshape(B, Fp, Hp, Wp, pt, ph, pw, C)
# GOOD: match Conv3d weight layout
patches = x.reshape(B, Fp, pt, Hp, ph, Wp, pw, C)
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6)  # [B, Fp, Hp, Wp, C, pt, ph, pw]
```
Verify numerically: the fixed version should match Conv3d output to ~1e-6.
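One way to run that check without a framework is the following NumPy sketch. It assumes a channels-last input [B, F, H, W, C], a PyTorch-layout Conv3d weight [O, C, pt, ph, pw], stride equal to the kernel size, and no bias or padding. `conv3d_patches` is a deliberately slow loop used as ground truth. All three function names are hypothetical:

```python
import numpy as np

def patchify_linear(x, w, pt, ph, pw):
    """GOOD: reshape + transpose + matmul, channels slowest to match Conv3d weights."""
    B, F, H, W, C = x.shape
    Fp, Hp, Wp = F // pt, H // ph, W // pw
    p = x.reshape(B, Fp, pt, Hp, ph, Wp, pw, C)
    p = p.transpose(0, 1, 3, 5, 7, 2, 4, 6)           # [B, Fp, Hp, Wp, C, pt, ph, pw]
    p = p.reshape(B, Fp, Hp, Wp, C * pt * ph * pw)
    return p @ w.reshape(w.shape[0], -1).T             # [B, Fp, Hp, Wp, O]

def patchify_naive(x, w, pt, ph, pw):
    """BAD: flattens each patch as [pt, ph, pw, C] (channels fastest)."""
    B, F, H, W, C = x.shape
    p = x.reshape(B, F // pt, H // ph, W // pw, pt * ph * pw * C)
    return p @ w.reshape(w.shape[0], -1).T

def conv3d_patches(x, w, pt, ph, pw):
    """Ground truth: apply the Conv3d kernel to each non-overlapping patch directly."""
    B, F, H, W, C = x.shape
    out = np.zeros((B, F // pt, H // ph, W // pw, w.shape[0]), dtype=x.dtype)
    for f in range(F // pt):
        for i in range(H // ph):
            for j in range(W // pw):
                block = x[:, f*pt:(f+1)*pt, i*ph:(i+1)*ph, j*pw:(j+1)*pw, :]  # [B, pt, ph, pw, C]
                out[:, f, i, j] = np.einsum("btyxc,octyx->bo", block, w)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 4, 6, 3))   # [B, F, H, W, C]
w = rng.standard_normal((5, 3, 2, 2, 2))   # [O, C, pt, ph, pw]
ref = conv3d_patches(x, w, 2, 2, 2)
print(np.abs(patchify_linear(x, w, 2, 2, 2) - ref).max())
print(np.abs(patchify_naive(x, w, 2, 2, 2) - ref).max())
```

The GOOD version agrees with the loop to floating-point round-off. The naive version differs by values on the order of the outputs themselves.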
mx.zeros / padding does not inherit dtype
Use dtype-aware mx.zeros for padding and concatenation to avoid promotion:
```python
# BAD: default float32 padding promotes bfloat16 input
pad = mx.zeros((B, pad_len, C))  # float32!
x = mx.concatenate([pad, x], axis=1)  # x promoted to float32
# GOOD: match input dtype
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
x = mx.concatenate([pad, x], axis=1)  # stays bfloat16
```
Use mx.fast kernels
Replace manual implementations with fused MLX kernels where possible:
```python
# Manual RMS norm → mx.fast.rms_norm
# Manual LayerNorm → mx.fast.layer_norm
# Manual attention → mx.fast.scaled_dot_product_attention
```
These are faster and handle precision internally.
4. Autoregressive Chunk Boundaries
For models that generate long videos by autoregressively extending chunks (Helios, CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.
Don't add post-processing the reference doesn't have
Bug: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.
The reference pipeline used no cross-fade at all. The first frame of each new chunk is intentionally a sharp reconstruction conditioned on history. Blending it with the previous chunk's tail (which has different content) creates blur.
Rule: Before adding smoothing/blending, verify the reference doesn't do it. Reference simplicity is usually correct.
First-frame artifacts are common
The first pixel frame of each non-first chunk is typically a distorted reconstruction of the conditioning frame. In many models, this is expected behavior:
- Fix: Drop the first frame from each chunk
- Verify frame math: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
History conditioning errors compound
Small errors in how history is prepared, sliced, patchified, or position-encoded will compound across chunks. The error is invisible in chunk 1, small in chunk 2, and catastrophic by chunk 5.
Debug strategy: Generate with frozen history (same history for every chunk). If the artifact disappears, the bug is in history handling.
5. VAE Decoder Artifacts
Causal temporal convolutions cause boundary warmup
Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding each chunk independently, the first few frames lack temporal context (only zero padding), causing:
- ~7% contrast drop in first frames of each chunk
- Spatial brightness redistribution (face darkens, background brightens)
This is inherent to the architecture. The reference has the same effect but at lower magnitude.
Post-processing to fix VAE warmup
Two-stage correction applied to first N frames of each non-first chunk:
```python
import cv2
import numpy as np

def correct_warmup_frame(cur_frame, ref_frame, ramp):
    """ref_frame: previous chunk's last frame; ramp: 1.0 → 0.0 over the first N frames."""
    cur = cur_frame.astype(np.float32)  # avoid uint8 wraparound in the subtraction
    ref = ref_frame.astype(np.float32)
    h, w = cur.shape[:2]
    # Stage 1: spatially-varying brightness correction from downsampled frames
    ref_small = cv2.resize(ref, (w // 16, h // 16), interpolation=cv2.INTER_AREA)
    cur_small = cv2.resize(cur, (w // 16, h // 16), interpolation=cv2.INTER_AREA)
    diff_full = cv2.resize(ref_small - cur_small, (w, h), interpolation=cv2.INTER_LINEAR)
    corrected = cur + ramp * diff_full
    # Stage 2: per-channel contrast matching around the current channel mean
    for c in range(3):
        ref_std = np.std(ref[:, :, c])
        mean = np.mean(corrected[:, :, c])
        cur_std = np.std(corrected[:, :, c])
        scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
        corrected[:, :, c] = (corrected[:, :, c] - mean) * scale + mean
    return np.clip(corrected, 0, 255).astype(np.uint8)
```
VAE overlap decode does NOT work
Attempted: Prepend previous chunk's last latent frames to give the decoder temporal context.
Result: Made things worse (22% contrast drop vs 7%). The causal convolutions see conflicting content from different chunks and create larger artifacts than zero-padding.
Lesson: Overlap only works when tiles contain the same content from the same denoising process (e.g., spatial tiling). It fails for temporal chunks with different content.
Per-chunk VAE decoding is correct
Decode each chunk's latents independently, not concatenated. Concatenating all chunks and decoding together lets boundary discontinuities propagate through temporal convolutions, creating worse artifacts.
First-frame quality: causal padding strategies
Multiple approaches were tried for the first-frame quality issue in Wan VAE:
| Approach | Result |
|---|---|
| Zero padding (default) | First ~4 frames degraded, but matches training |
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias |
Lesson: Don't assume "replicate padding is better than zero padding." The model was trained with zero padding; changing it shifts the gain. Instead, prepend warmup frames and trim them after decoding.
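The padding strategies in the table map onto np.pad modes along the time axis, which makes the difference in context easy to inspect. This is a NumPy sketch. It assumes frames (or latent frames) are stacked on axis 0. `prepend_context` is a hypothetical helper. After decoding, slice off the outputs produced by the k prepended frames; how many output frames that is depends on the VAE's temporal stride:

```python
import numpy as np

# Table row → np.pad mode
PAD_MODES = {"zero": "constant", "replicate": "edge", "mirror": "reflect"}

def prepend_context(frames, k, strategy):
    """Prepend k context frames along axis 0 (time) using the given strategy."""
    pad = [(k, 0)] + [(0, 0)] * (frames.ndim - 1)
    return np.pad(frames, pad, mode=PAD_MODES[strategy])

frames = np.arange(1.0, 5.0).reshape(4, 1)   # four 1-pixel "frames": 1, 2, 3, 4
for strategy in PAD_MODES:
    print(strategy, prepend_context(frames, 2, strategy)[:, 0].tolist())
```

Mirror-reflect gives the decoder varied, real-looking context. Replicate feeds the same value to every kernel tap, which is where the intensity bias comes from.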
RMS_norm vs L2 normalization
Bug (Wan2.2): Garbled output from incorrect normalization.
A PyTorch class named RMS_norm actually uses F.normalize (L2 norm: x / ||x||_2),
not RMS normalization (x / sqrt(mean(x²))). The difference is a factor of sqrt(C),
causing values to explode through the decoder.
Lesson: Don't trust class names — read the actual implementation.
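The factor is exact: ||x||_2 = sqrt(C) · sqrt(mean(x²)). RMS-normalized outputs are therefore sqrt(C) times larger than L2-normalized ones. A quick NumPy check follows. C = 64 is arbitrary, and `l2_normalize` mirrors torch.nn.functional.normalize, which divides by max(||x||_2, eps):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """x / sqrt(mean(x^2)) over the channel (last) axis."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def l2_normalize(x, eps=1e-12):
    """x / max(||x||_2, eps) over the last axis."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)

C = 64
x = np.random.default_rng(0).standard_normal((4, C))
ratio = np.linalg.norm(rms_norm(x), axis=-1) / np.linalg.norm(l2_normalize(x), axis=-1)
print(ratio)  # sqrt(C) for every row, up to eps
```

If a reference multiplies an L2-normalized tensor by its own sqrt(C) scale, pairing that scale with a true RMS norm overshoots by another factor of sqrt(C).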
Temporal frame count: causal boundary effects
Bug (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.
PyTorch reference processes frames one-by-one with caching, skipping temporal conv for the first chunk. All-at-once decoding produces extra frames from zero-padded causal context.
Fix: Use first_chunk=True flag to trim causal boundary frames, matching the
reference's chunked behavior.
Chunked VAE encoding for I2V
Bug (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.
Non-chunked encoding with causal zero-padding produces incorrect latents because temporal features don't propagate correctly without caching. The reference uses chunked encoding (1+4+4+... frames) with persistent temporal cache.
Fix: Implement chunked encoding with feat_cache propagation through CausalConv3d,
ResidualBlock, and Resample layers.
6. Scheduler & Timestep Issues
Copy formulas exactly
Even small differences in scheduler formulas compound over many steps:
```python
# Dynamic time shifting: reference uses a specific formula
mu = 0.5 + shift * 0.5  # NOT shift * 0.6 or any other constant
# Euler step
x_next = x + (sigma_next - sigma) * flow  # order matters: next - current
```
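A toy check of the step direction and the x0 formula from Section 2. It assumes the rectified-flow convention x_t = (1 - sigma) * x0 + sigma * noise with flow = noise - x0, which is consistent with x0 = xt - sigma * flow. With an oracle flow, Euler steps stay on the straight path and land exactly on x0. This is a sanity test for sign and ordering, not a model:

```python
import numpy as np

def euler_step(x, flow, sigma, sigma_next):
    # (sigma_next - sigma) is negative while denoising; flipping it walks back toward noise
    return x + (sigma_next - sigma) * flow

rng = np.random.default_rng(0)
x0 = rng.standard_normal(8)        # stand-in for the clean latent
noise = rng.standard_normal(8)
sigmas = np.linspace(1.0, 0.0, 11)

x = noise.copy()                   # sigma = 1.0: pure noise
for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
    flow = noise - x0              # oracle velocity under the assumed convention
    assert np.allclose(x - sigma * flow, x0)   # x0 = xt - sigma * flow at every step
    x = euler_step(x, flow, sigma, sigma_next)
```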
Verify sigma schedules numerically
Print and compare sigma values at each step:
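```python
# Reference
sigmas_ref = [1.0, 0.99375, 0.9875, ...]  # truncated here; use the full reference list
# Your implementation
sigmas = scheduler.get_sigmas(steps)
assert len(sigmas) == len(sigmas_ref), "schedule lengths differ"
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
    assert abs(r - float(m)) < 1e-6, f"Step {i}: ref={r}, mlx={m}"
```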
Timestep embedding precision
Integer vs float timesteps matter. Some models expect timestep=999 (int), others
expect timestep=0.999 (float). Wrong type can silently produce wrong embeddings
with reasonable-looking but incorrect statistics.
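To see why the wrong scale is hard to spot, the following NumPy sketch uses a generic sinusoidal timestep embedding. The cos/sin layout and max_period=10000 are assumptions; check your model's own embedding. Both inputs give values bounded in [-1, 1] with plausible-looking statistics, yet the two vectors barely resemble each other:

```python
import numpy as np

def sinusoidal_embedding(t, dim=256, max_period=10000.0):
    """Generic sinusoidal timestep embedding: [cos(t * f_i), sin(t * f_i)]."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = float(t) * freqs
    return np.concatenate([np.cos(args), np.sin(args)])

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

e_int = sinusoidal_embedding(999)      # timestep on an integer 0-1000 scale
e_float = sinusoidal_embedding(0.999)  # the same step on a 0-1 scale
for name, e in (("int", e_int), ("float", e_float)):
    print(f"{name}: mean={e.mean():.3f} std={e.std():.3f} min={e.min():.3f} max={e.max():.3f}")
print(f"cosine similarity: {cosine(e_int, e_float):.3f}")
```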
Boundary conditions: ±inf at sigma endpoints
Bug (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.
The lambda(sigma) function must return -inf at sigma=1.0 (pure noise) and +inf
at sigma=0.0 (clean signal). Our implementation returned 0.0, causing massive x0
overscaling on the first denoising step.
PyTorch naturally computes torch.log(0) = -inf, and math.expm1(-inf) = -1.0
handles the formulas correctly. Reproduce this behavior explicitly:
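```python
import math

def _lambda(self, sigma):
    # lambda = log(alpha / sigma) with alpha = 1 - sigma
    if sigma >= 1.0:
        return float('-inf')  # pure noise
    if sigma <= 0.0:
        return float('inf')   # clean signal
    return -math.log(sigma / (1 - sigma))
```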
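The failure is easy to reproduce with the textbook first-order DPM-Solver++ (data-prediction) update, x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred with h = lambda_t - lambda_s and alpha = 1 - sigma. This is the generic formula, not necessarily the exact code of any particular scheduler, and the helper names are hypothetical. With lambda(1) = -inf, expm1(-inf) = -1, and the first step reduces to the expected sigma_t * x + alpha_t * x0_pred. With lambda(1) = 0, the x0 coefficient becomes 1 - 2 * sigma_t, which is badly wrong when sigma_t is close to 1:

```python
import math

def lam(sigma):
    """lambda = log(alpha / sigma), alpha = 1 - sigma, with explicit endpoints."""
    if sigma >= 1.0:
        return -math.inf
    if sigma <= 0.0:
        return math.inf
    return math.log((1.0 - sigma) / sigma)

def lam_buggy(sigma):
    """The bug described above: 0.0 instead of -inf/+inf at the endpoints."""
    if sigma >= 1.0 or sigma <= 0.0:
        return 0.0
    return math.log((1.0 - sigma) / sigma)

def dpmpp_1st_order(x, x0_pred, sigma_s, sigma_t, lam_fn):
    """First-order DPM-Solver++ step from sigma_s to sigma_t."""
    alpha_t = 1.0 - sigma_t
    h = lam_fn(sigma_t) - lam_fn(sigma_s)
    return (sigma_t / sigma_s) * x - alpha_t * math.expm1(-h) * x0_pred

noise, x0_pred = 1.0, 0.5
good = dpmpp_1st_order(noise, x0_pred, 1.0, 0.99, lam)       # 0.99 * noise + 0.01 * x0
bad = dpmpp_1st_order(noise, x0_pred, 1.0, 0.99, lam_buggy)  # x0 weight becomes 1 - 2 * 0.99
print(good, bad)
```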
UniPC corrector coefficients
Bug (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.
The UniPC corrector must compute rhos_c via linalg.solve for order ≥ 2. Hardcoded
0.5 was 7× too large for the history weight (actual: ~0.08), causing massive
overweighting of history corrections.
7. Weight Conversion
Always verify statistically
After converting weights from PyTorch to MLX format:
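```python
import mlx.core as mx
import numpy as np

for name in mlx_weights:
    pt = pytorch_weights[map_name(name)]
    # NumPy has no bfloat16: upcast both sides before converting
    mx_val = np.array(mlx_weights[name].astype(mx.float32), dtype=np.float64)
    pt_val = pt.float().numpy().astype(np.float64)
    if mx_val.shape != pt_val.shape:  # conv weights may need the conversion-time transpose
        print(f"SHAPE MISMATCH: {name} {mx_val.shape} vs {pt_val.shape}")
        continue
    cos_sim = np.dot(mx_val.ravel(), pt_val.ravel()) / (
        np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
    )
    if cos_sim < 0.9999:
        print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
```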
Conv3d → Linear reshaping
When converting 3D convolutions to linear layers (common for MLX which prefers linear ops), the flattening order must match:
```python
# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
# The reshape order MUST match how the input is patchified
```
Sanitization functions
Write explicit weight sanitization that maps reference key names to your key names. Don't rely on automatic matching — key naming conventions differ between frameworks.
Module structure must mirror reference for direct loading
Bug (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch nn.Sequential
structure. ResidualBlock needed None gaps at specific indices to match the original
nn.Sequential(RMSNorm, SiLU, Conv3d, ...) indexing.
When possible, structure your modules to accept reference weights directly without sanitization. This eliminates an entire class of bugs.
Save VAE weights in float32
Even if the model uses bfloat16 for the transformer, save VAE weights in float32. Storing them as bfloat16 discards mantissa bits that a load-time upcast to float32 cannot recover.
Temporal downsample/upsample order
Bug (Wan2.2): temporal_downsample=[True, True, False] but reference uses
[False, True, True]. Stage 0 created a time_conv with random weights (no matching
file key), and Stage 2 missed its time_conv (weights silently dropped).
Always verify boolean flags for each stage by inspecting the actual weight file keys.
Silent weight drops are the worst bugs
When load_weights() with strict=False silently skips keys that don't match, you
get a model with random weights for those layers. This produces output that looks
"almost right" but is subtly wrong. Always log which keys were loaded vs skipped:
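```python
from mlx.utils import tree_flatten

# weights: list of (key, array) pairs read from the checkpoint
model_params = dict(tree_flatten(model.parameters()))
loaded_keys = set()
for key, value in weights:
    if key in model_params:
        loaded_keys.add(key)
# Check for missing and unused keys
expected = set(model_params.keys())
missing = expected - loaded_keys
unexpected = {key for key, _ in weights} - expected
if missing:
    print(f"WARNING: {len(missing)} weights not loaded: {sorted(missing)[:5]}...")
if unexpected:
    print(f"WARNING: {len(unexpected)} file keys unused: {sorted(unexpected)[:5]}...")
```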
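Shape mismatches are the other silent failure worth catching at the same point. The following stdlib sketch reports missing, unexpected, and wrongly shaped keys in one pass. `diff_weight_keys` is hypothetical. In MLX, the model side can be built as `{k: v.shape for k, v in tree_flatten(model.parameters())}`:

```python
def diff_weight_keys(expected_shapes, file_shapes):
    """Compare {key: shape} dicts from the model and the weight file.

    Returns (missing, unexpected, shape_mismatch) as sorted lists of keys.
    """
    expected, found = set(expected_shapes), set(file_shapes)
    missing = sorted(expected - found)
    unexpected = sorted(found - expected)
    mismatched = sorted(
        k for k in expected & found
        if tuple(expected_shapes[k]) != tuple(file_shapes[k])
    )
    return missing, unexpected, mismatched

model_shapes = {"blocks.0.attn.q.weight": (64, 64), "blocks.0.norm.weight": (64,)}
file_shapes = {"blocks.0.attn.q.weight": (32, 64), "blocks.0._norm.weight": (64,),
               "head.weight": (16, 64)}
missing, unexpected, mismatched = diff_weight_keys(model_shapes, file_shapes)
print(missing, unexpected, mismatched)
```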
8. Text Conditioning Failures
Symptom: model predicts noise back to itself
If the model output correlates > 0.8 with its input noise, text conditioning is likely broken. The model has learned nothing from the prompt and is just returning its input.
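The correlation check itself is a single NumPy call. In the following sketch, `noise_correlation` is a hypothetical helper, and MLX arrays should be cast with `.astype(mx.float32)` before conversion. The synthetic data only illustrates the two regimes around the 0.8 threshold quoted above:

```python
import numpy as np

def noise_correlation(model_output, input_noise):
    """Pearson correlation between a model's output and the noise it was fed."""
    a = np.asarray(model_output, dtype=np.float64).ravel()
    b = np.asarray(input_noise, dtype=np.float64).ravel()
    return float(np.corrcoef(a, b)[0, 1])

rng = np.random.default_rng(0)
noise = rng.standard_normal(4096)
echo = noise + 0.1 * rng.standard_normal(4096)   # output that "predicts noise back"
independent = rng.standard_normal(4096)          # output unrelated to the input noise
print(noise_correlation(echo, noise), noise_correlation(independent, noise))
```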
Check embedding statistics
```python
text_emb = text_encoder(prompt)
print(f"text_emb: mean={text_emb.mean().item():.4f} std={text_emb.std().item():.4f}")
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
# std > 10.0 → embeddings are exploding → wrong normalization
```
Verify with ablation
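```python
import mlx.core as mx

# Generate with real text
output_text = denoise(latent, text_emb=real_embeddings)
# Generate with zeros
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
# Compare: absolute difference, and relative to the output's magnitude
text_influence = mx.mean(mx.abs(output_text - output_zero)).item()
relative = text_influence / (mx.mean(mx.abs(output_text)).item() + 1e-8)
print(f"Text influence: {text_influence:.4f} ({relative:.0%} of output)")  # > 0; typically 30-60%
```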
Text preprocessing must match exactly
Bug (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.
The official Wan2.2 tokenizer applies ftfy.fix_text + html.unescape + whitespace
normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C)
tokenize differently from ASCII commas (U+002C), causing 27 different token IDs
in the negative prompt. This made CFG's unconditional prediction wrong.
Fix: Apply the same text cleaning pipeline as the reference:
```python
import html
import re

import ftfy

def clean_text(text):
    text = ftfy.fix_text(text)                # ftfy's Unicode fixes
    text = html.unescape(text)                # HTML entities, e.g. &amp; → &
    text = re.sub(r'\s+', ' ', text).strip()  # collapse whitespace
    return text
```
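Before wiring up the full cleaner, a stdlib scan can flag characters that are likely to tokenize differently. This sketch uses Unicode NFKC normalization as a stand-in detector. NFKC is not what ftfy does internally, and the scan does not replace running the reference's clean_text. `width_variants` is a hypothetical helper:

```python
import unicodedata

def width_variants(text):
    """Return (char, codepoint, NFKC form) for characters NFKC would rewrite."""
    found = []
    for ch in sorted(set(text)):
        norm = unicodedata.normalize("NFKC", ch)
        if norm != ch:
            found.append((ch, f"U+{ord(ch):04X}", norm))
    return found

print(width_variants("bright，blurry"))  # contains a fullwidth comma (U+FF0C)
```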
T5 encoder precision
Bug (Wan2.2): Quality degradation from bfloat16 T5 attention.
T5 uses no scaling in attention (no 1/sqrt(d) factor), so attention logits can
be very large. bfloat16 softmax loses significant precision across 24 encoder layers.
Fix: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per generation, so the performance impact is negligible.
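The precision loss can be shown with two large logits that differ by 1.0. NumPy has no bfloat16, so this sketch uses float16 as a stand-in. float16 keeps 10 mantissa bits versus bfloat16's 7. float16 therefore only loses a gap of 1.0 around 2048, whereas bfloat16 loses it around 256 and above. The logit values are illustrative:

```python
import numpy as np

def softmax(logits, dtype):
    """Max-subtracted softmax computed entirely in `dtype`."""
    z = np.asarray(logits, dtype=dtype)
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits = [2048.0, 2049.0]          # large, unscaled QK^T values
p32 = softmax(logits, np.float32)  # keeps the 1.0 gap
p16 = softmax(logits, np.float16)  # 2049 rounds to 2048, so the gap disappears
print(p32, p16)
```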
Dual-model text embeddings
Bug (Wan2.2 I2V-14B): Low/high noise models have different text_embedding weights
(~42% relative difference). Using one model's embeddings for both caused incorrect
text conditioning for the high-noise model that handles critical early denoising steps.
Fix: Compute separate text embeddings for each model in dual-model setups.
9. Position Encodings (RoPE)
Multi-scale consistency
In pyramid/multi-resolution models, RoPE must be computed consistently across scales. If the model operates at 1/4 resolution in an early stage, the position grid must reflect the actual spatial dimensions, not the final target dimensions.
History vs current chunk
When conditioning on history from a previous chunk, the position encoding for history frames must match what the model saw during training. Mismatches between history and current-chunk position encodings can cause subtle spatial distortions that compound across chunks.
Factorized RoPE
3D video models often use factorized RoPE (separate temporal, height, width frequencies). Verify each axis independently:
```python
# Compare temporal frequencies
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
# Compare spatial frequencies
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)
```
Per-axis frequency construction
Bug (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.
The reference uses three separate rope_params() calls with different dimension
normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency
range. Consolidating into a single rope_params(head_dim) call and splitting gave
height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).
Fix (and subsequent revert): This bug was introduced as a "fix" for a previous RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects. Always verify with actual generation, not just numerical comparison of frequencies.
Lesson: Read the reference's frequency construction very carefully. Don't "simplify" three separate calls into one unless you verify the frequency distribution matches exactly.
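The following NumPy sketch compares where each axis's frequencies start under the two constructions. For head_dim = 128, the split head_dim - 4 * (head_dim // 6), 2 * (head_dim // 6), 2 * (head_dim // 6) reproduces the 44/42/42 numbers above. Treat that formula, theta = 10000, and the helper names as assumptions to check against your reference. The single-call version reproduces the 0.042 and 0.002 starting frequencies quoted above:

```python
import numpy as np

def rope_inv_freqs(dim, theta=10000.0):
    """Inverse frequencies theta^(-2i/dim) for i = 0 .. dim/2 - 1."""
    return 1.0 / theta ** (np.arange(0, dim, 2, dtype=np.float64) / dim)

def axis_dims(head_dim):
    """(t, h, w) rotary sizes; 44, 42, 42 for head_dim = 128."""
    d6 = head_dim // 6
    return head_dim - 4 * d6, 2 * d6, 2 * d6

def per_axis_freqs(head_dim):
    """Three separate calls: every axis starts at frequency 1.0."""
    return [rope_inv_freqs(d) for d in axis_dims(head_dim)]

def single_call_freqs(head_dim):
    """The 'simplified' version: one call over head_dim, then split."""
    sizes = [d // 2 for d in axis_dims(head_dim)]   # 22, 21, 21 frequencies
    return np.split(rope_inv_freqs(head_dim), np.cumsum(sizes)[:-1])

for name, freqs in (("per-axis", per_axis_freqs(128)), ("single call", single_call_freqs(128))):
    print(name, [round(float(f[0]), 4) for f in freqs])
```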
10. Multi-Stage / Pyramid Pipelines
Each stage is a potential failure point
Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the number of things that can go wrong:
- Downsampling method (bilinear vs area) must match reference
- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
- Alpha/beta noise mixing coefficients are stage-dependent
- Frame indices and history resolution change per stage
Test single-stage first
If the model works at full resolution for a single stage but fails in the pyramid, the bug is in stage orchestration — typically in how latents are passed between stages or how position encodings adapt to different resolutions.
Integration bugs are the hardest
We verified every Helios component matched the reference individually, but the pyramid still produced uniform color. The bug was in dtype handling during stage transitions. Integration bugs only appear when components interact.
11. Common Symptoms → Root Causes
| Symptom | Likely Root Causes |
|---|---|
| Pure noise output | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
| Uniform color | Model predicting noise back; text embeddings collapsed; wrong timestep format |
| Progressive zoom/shrink | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
| Brightness jumps at boundaries | VAE causal warmup; cross-fade blending misaligned content |
| Color drift across chunks | Dtype in scheduler step; history normalization missing |
| Blur at boundaries | Cross-fade enabled; latent blending; wrong VAE decode order |
| Grid/checker patterns | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
| Green/magenta tint | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
| Mean drift across steps | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
| Garbled/scrambled output | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
| Greenish-yellow constant | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
| ~2x slower than expected | Float32 promotion cascade from single mistyped intermediate |
| Extra output frames | Causal padding producing extra temporal frames; missing first_chunk trim |
| Grey/artifact output | RoPE frequency construction wrong (per-axis vs single-call) |
| Patchy-blurry with CFG | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
| I2V temporal mismatch | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |
12. Verification Checklist
Use this checklist when porting a new diffusion video model:
Model
- Weight conversion: all keys mapped, cosine similarity > 0.9999
- No silent weight drops (log loaded vs expected keys)
- Single forward pass matches reference (cos_sim > 0.999)
- Residual connections use float32 accumulation
- Attention computation matches reference precision
- Modulation/gate vectors in float32 (if reference uses autocast)
- No underscore-prefixed module attributes (MLX ignores them)
Scheduler
- Sigma values match reference at every step (diff < 1e-6)
- Timestep format correct (int vs float, scale factor)
- Dynamic shifting formula copied exactly
- Step function returns correct dtype (float32)
- Boundary conditions: lambda = -inf at sigma=1, lambda = +inf at sigma=0
- Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
Text Encoder
- Embedding statistics reasonable (0.01 < std < 10)
- Text influence > 0 (ablation test)
- Tokenization matches (special tokens, padding, max length)
- Text preprocessing matches (ftfy, html unescape, whitespace normalization)
- T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
- Separate embeddings for dual-model setups (if applicable)
VAE
- Denormalization constants match training pipeline
- Per-chunk decoding (not concatenated)
- Temporal frame count correct (account for causal padding)
- Weight keys mapped correctly (encoder vs decoder)
- Weights stored/loaded in float32 (not bfloat16)
- Temporal downsample/upsample order matches reference
- RMS_norm vs L2_norm: check actual implementation, not class name
- Chunked encoding for I2V (if applicable)
- Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
Pipeline Orchestration
- Position encodings consistent across stages/chunks
- History slicing and conditioning correct
- Noise generation matches (distribution, correlation structure)
- Multi-chunk output visually consistent (no progressive degradation)
- Dimension auto-alignment (divisible by patch_size × vae_stride)
- Dtype-aware padding (mx.zeros with explicit dtype)
Output
- Frame count matches expected (account for warmup/trim)
- FPS correct
- Color range [0, 255] uint8 for video
- No first-frame duplication artifacts
- Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
Performance
- No float32 promotion cascades (check with profiler)
- Using mx.fast kernels (rms_norm, layer_norm, sdpa)
- Time embedding computed once per sample (not per position)
- Memory cleanup (delete temporaries before mx.eval)
13. Diagnostic Tools
General video diagnostics (scripts/video/)
| Script | Purpose |
|---|---|
| compare_videos.py | PSNR, SSIM, temporal coherence, color fidelity between two videos |
| video_quality.py | Sharpness, stability, defect detection, chunk boundary analysis |
```bash
# Quick quality check
python scripts/video/video_quality.py output.mp4 --chunk-size 32
# Compare against reference
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
```
Model-specific diagnostics (scripts/helios/)
| Script | Purpose |
|---|---|
| analyze_boundaries.py | Detailed boundary quality metrics for Helios |
| run_reference.py | Run PyTorch reference on MPS |
| compare_pipelines.py | Compare scheduler/pipeline mechanics |
| compare_models.py | Cross-framework model output comparison |
Inline debugging pattern
Add temporary debug output to the diffusion loop:
```python
import os

import mlx.core as mx

for i, sigma in enumerate(sigmas[:-1]):
    sigma_next = sigmas[i + 1]
    flow = model(latent, sigma, text_emb)
    latent = scheduler.step(latent, flow, sigma, sigma_next)
    # Debug: track statistics
    print(f"[step {i}] sigma={float(sigma):.4f} "
          f"latent: mean={latent.mean().item():.6f} std={latent.std().item():.6f} "
          f"flow: mean={flow.mean().item():.6f} std={flow.std().item():.6f}")
    # Debug: save for cross-framework comparison (float32: NumPy cannot read bfloat16)
    if os.environ.get("DEBUG"):
        mx.savez(f"/tmp/debug_step_{i}.npz",
                 latent=latent.astype(mx.float32), flow=flow.astype(mx.float32),
                 sigma=mx.array(float(sigma)))
```
Key Takeaways
- Precision is the #1 bug source: bfloat16 residuals, scheduler math, type promotion, modulation vectors. Copy the reference's .float() and autocast zones.
- Don't add what the reference doesn't have: cross-fade, overlap decode, temporal blending. If the reference works without it, you probably have a bug elsewhere.
- Silent failures are the hardest bugs: underscore-prefixed weights, strict=False weight loading, wrong normalization class names. Always verify weight load counts and output statistics.
- Component isolation → integration testing: verify each part matches, then debug their interaction.
- Statistical comparison beats visual inspection: mean drift, contrast ratios, and cosine similarity catch bugs before they're visible.
- Autoregressive errors compound: a 1% error per chunk becomes 10% by chunk 10. Fix precision first, add corrections second.
- MLX has unique pitfalls: underscore attribute names, reshape axis ordering, dtype-unaware padding, and float32 promotion cascades. Know your framework.
- Text preprocessing matters: Unicode normalization, fullwidth chars, HTML entities. A single mismatched comma can break CFG guidance.
- VAE is deceptively complex: causal padding, temporal frame counts, chunked vs batch processing, norm implementations. Budget significant debugging time for VAE.