
Porting Diffusion Video Models to MLX: Lessons Learned

A practical guide distilled from porting Wan2.1/2.2 (1.3B–14B) and Helios 14B DiT video generation models from PyTorch to MLX on Apple Silicon. These lessons apply broadly to any diffusion-based video (or image) model port.


Table of Contents

  1. Debugging Methodology
  2. Precision & Dtype Pitfalls
  3. MLX-Specific Gotchas
  4. Autoregressive Chunk Boundaries
  5. VAE Decoder Artifacts
  6. Scheduler & Timestep Issues
  7. Weight Conversion
  8. Text Conditioning Failures
  9. Position Encodings (RoPE)
  10. Multi-Stage / Pyramid Pipelines
  11. Common Symptoms → Root Causes
  12. Verification Checklist
  13. Diagnostic Tools

1. Debugging Methodology

Component isolation first

Never debug the full pipeline. Test each component in isolation:

  1. Text encoder — Does it produce embeddings with reasonable statistics? (std > 0.01)
  2. Scheduler — Do sigma/timestep values match the reference exactly?
  3. Transformer — Does a single forward pass match the reference? (cosine similarity > 0.999)
  4. VAE decoder — Feed reference latents into your VAE. Does the output look correct?

If every component matches individually but the pipeline fails, the bug is in orchestration — how components are wired together.

Statistical fingerprinting

Track per-step statistics through the diffusion loop:

# After each denoising step
print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
      f"min={latent.min():.4f} max={latent.max():.4f}")

What to look for:

  • Progressive mean drift (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
  • Collapsing std (std dropping toward 0) signals broken conditioning or wrong noise schedule
  • Exploding values signal wrong sigma scaling or scheduler formula
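
These three patterns can be checked mechanically once the per-step statistics are collected. The following is a minimal standard-library sketch; the function name `diagnose_stats` and every threshold in it are assumptions chosen for illustration, not values taken from the port.

```python
def diagnose_stats(stats, drift_tol=0.05, collapse_ratio=0.25, explode_ratio=4.0):
    """Flag suspicious trends in per-step latent statistics.

    stats: list of (mean, std) tuples, one per denoising step.
    All thresholds are illustrative assumptions; tune them per model.
    """
    if len(stats) < 2:
        return []
    means = [m for m, _ in stats]
    stds = [s for _, s in stats]
    deltas = [b - a for a, b in zip(means, means[1:])]

    findings = []
    # Drift: the mean moves the same way at every step and ends far from the start.
    one_direction = all(d < 0 for d in deltas) or all(d > 0 for d in deltas)
    if one_direction and abs(means[-1] - means[0]) > drift_tol:
        findings.append("progressive mean drift")
    # Collapse / explosion: final std compared with the first step's std.
    if stds[-1] < collapse_ratio * stds[0]:
        findings.append("collapsing std")
    if stds[-1] > explode_ratio * stds[0]:
        findings.append("exploding values")
    return findings


# The drift example from the list above: -0.002 -> -0.040 -> -0.123
report = diagnose_stats([(-0.002, 1.0), (-0.040, 0.98), (-0.123, 0.95)])
```

Fixed thresholds are only a first filter, because std legitimately changes as noise is removed. Comparing against the same statistics logged from the reference run is more reliable.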

Cross-framework numerical comparison

The most powerful debugging tool: save intermediate tensors from your MLX pipeline, feed them to the PyTorch reference, compare outputs.

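# In the MLX pipeline, save inputs before the transformer call.
# mx.save writes a single array; use mx.savez for several named arrays.
# Cast to float32 so NumPy can read the file (NumPy has no bfloat16).
mx.savez("debug_inputs.npz",
         latent=latent.astype(mx.float32),
         timestep=mx.array(t),
         text_emb=text_emb.astype(mx.float32))

# In the PyTorch script, load and compare
import numpy as np
import torch
import torch.nn.functional as F

inputs = np.load("debug_inputs.npz")
mlx_out = np.load("debug_output.npz")["flow"]
pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
cos_sim = F.cosine_similarity(pt_out.flatten().float(), torch.from_numpy(mlx_out).flatten(), dim=0)
# cos_sim > 0.999 = model is correct; bug is elsewhere
# cos_sim < 0.99  = model has a bug; compare per-layer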

Ablation testing

When a pipeline has multiple "fixes" or features, disable them one at a time:

  • Frozen history: Fix history to the same value for all chunks → proves whether history propagation is the source of drift/zoom
  • Single chunk: Generate only 1 chunk → isolates per-chunk quality from multi-chunk interaction bugs
  • Disable post-processing: Remove cross-fade, blending, corrections → reveals what the raw model output looks like

Use reference on same hardware

Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS produce different numerical results due to different float handling. Comparing your MLX output against a CUDA reference adds noise to the comparison.

# MPS does not support float64, so force float32 in the reference:
import torch

original_linspace = torch.linspace
def patched_linspace(*args, **kwargs):
    kwargs.pop("dtype", None)  # discard any requested dtype (e.g. float64)
    return original_linspace(*args, dtype=torch.float32, **kwargs)
torch.linspace = patched_linspace

2. Precision & Dtype Pitfalls

The #1 source of subtle bugs

Precision issues caused the most insidious bugs in our port. They don't cause crashes — they cause progressive quality degradation that's hard to attribute.

Residual connections MUST be float32

Bug: Progressive zoom/shrinking across autoregressive chunks.

Root cause: Residual additions (x = x + attn_out) in bfloat16. With 7-bit mantissa, high-frequency spatial detail is systematically truncated. Over 144 residual ops × 6+ model calls per chunk, detail is progressively smoothed away.

Fix: Promote to float32 for the addition:

# BAD — bfloat16 accumulation
x = x + attn_out

# GOOD — match reference's .float() pattern
x = (x.astype(mx.float32) + attn_out).astype(weight_dtype)

Rule: If the reference uses .float() anywhere, copy that pattern exactly. It's there for a reason, even if a quick test seems to work without it.
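
The truncation effect can be reproduced without any ML framework. The sketch below emulates bfloat16 and float32 rounding with the standard library and applies 144 small residual updates to a value near 1.0. The helper names and the update size of 1e-3 are assumptions for illustration; real residual updates vary in magnitude.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def accumulate(start, update, n, cast):
    """Apply x = x + update n times, rounding to `cast` precision after every add."""
    x = cast(start)
    for _ in range(n):
        x = cast(x + cast(update))
    return x


# 144 residual additions, as in the bug description above.
bf16_total = accumulate(1.0, 1e-3, 144, to_bf16)  # every update is rounded away
f32_total = accumulate(1.0, 1e-3, 144, to_f32)    # updates survive
print(bf16_total, f32_total)
```

Near 1.0 the spacing between adjacent bfloat16 values is 2^-7 (about 0.0078), so any update smaller than half of that is lost on every addition, while a float32 stream keeps it.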

Scheduler computations need high precision

Diffusion schedulers involve:

  • x0 = xt - sigma * flow — catastrophic cancellation near sigma ≈ 1
  • log(sigma) and exp() — sensitive to small precision differences

Some references use float64 for these computations. MLX GPU doesn't support float64, so use float32 and accept small numerical differences, but never use bfloat16 for scheduler math.

Dtype propagation is invisible

Track dtype through your pipeline. A single bfloat16 intermediate can silently downcast everything downstream:

# This looks harmless, but if noise and model_output are both bfloat16
# (and sigma is a Python float), the result is bfloat16 too:
result = noise - sigma * model_output  # result is bfloat16!

# Fix: explicit cast
result = noise.astype(mx.float32) - sigma * model_output.astype(mx.float32)

Type promotion rules differ across frameworks

  • PyTorch: bfloat16 + float32 → float32
  • MLX: bfloat16 + float32 → float32 (same, but verify)
  • NumPy: no bfloat16 support

Always check what your framework does and match the reference's implicit promotations.

Float32 for VAE decoding

Bug (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.

Official Wan2.2 runs VAE decode in torch.float (float32), but our converted weights were bfloat16. The VAE has many sequential layers where precision loss compounds.

Fix: Upcast VAE weights to float32 at load time. The VAE runs once per generation, so the performance impact is negligible compared to the transformer.

Modulation/gate vectors need float32

Bug (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.

The official Wan2.2 explicitly uses torch.amp.autocast('cuda', dtype=torch.float32) for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.

Fix: Keep modulation in float32, cast to working dtype only when applying to the hidden state:

# Modulation computed in float32
e0 = self.modulation(time_emb)  # float32
scale, shift, gate = e0.split(3, axis=-1)

# Cast to the hidden-state dtype (bfloat16) only for the elementwise modulation
x = x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype)

Map PyTorch autocast zones precisely

PyTorch models use nested torch.amp.autocast scopes to switch precision. Map these exactly:

  • Outer scope (bfloat16): attention QKV projections, FFN matmuls
  • Inner scope (float32): modulation, gates, norms, RoPE
  • Residual stream: float32 (the "backbone" between blocks)

# Wan2.2 dtype flow (matches official):
# Modulation/gates: float32 (explicit)
# QKV/FFN linear projections: bfloat16 (weight dtype)
# RoPE: float32 (official uses float64, MLX lacks float64)
# Attention Q/K: cast back to bfloat16 after RoPE
# Residual stream: float32

Float32 promotion cascades kill performance

Bug (Wan2.2): ~2x slowdown from accidental float32 promotion.

A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations promotes the entire computation graph to float32. In Wan2.2:

  • Time embedding MLP output (float32) fed into transformer → all layers float32
  • RoPE frequencies (float32) applied to Q/K → all attention float32

Fix: Cast intermediate results to model dtype at promotion boundaries:

# After time embedding MLP (float32), cast before feeding to transformer
time_emb = time_mlp(t).astype(model_dtype)

# After RoPE (float32), cast Q/K back to attention dtype
q = rope_apply(q, freqs).astype(v.dtype)

3. MLX-Specific Gotchas

Underscore-prefixed attributes are invisible

Bug (Wan2.2): 87 of 110 VAE weights silently dropped during loading.

MLX's nn.Module.parameters() and nn.Module.load_weights() skip attributes whose names start with underscore. If you name a layer self._layer_0, its weights will never be loaded or saved.

# BAD — weights silently ignored
self._layer_0 = nn.Linear(...)  # nn.Module skips _prefixed attrs

# GOOD
self.layer_0 = nn.Linear(...)

This is especially insidious because there's no error — the model loads, runs, and produces output. The output is just garbage because most weights are random.

nn.Sequential indexing vs named children

PyTorch's nn.Sequential uses integer indices (sequential.0.weight), while MLX's module hierarchy uses named attributes. When mirroring a PyTorch module structure, you need explicit key sanitization:

import re

def sanitize_key(key):
    # PyTorch: "decoder.middle.0.residual.1.weight"
    # MLX:    "decoder.middle.layer_0.residual.layer_1.weight"
    # Only rewrite path components that are purely numeric.
    key = re.sub(r'\.(\d+)(?=\.|$)', lambda m: f'.layer_{m.group(1)}', key)
    return key

Reshape axis ordering differs from PyTorch

Bug (Wan2.2): Green checkerboard pattern from VAE attention.

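[B,C,T,H,W] cannot be directly reshaped to [BT,C,H,W]: in row-major memory C sits outside T, so a plain reshape folds channel indices into the batch axis. Reshape never reorders data, in MLX or in PyTorch. A reference that appears to do this in one step is usually permuting first (for example with einops rearrange 'b c t h w -> (b t) c h w'), so the port needs the same transpose written out explicitly: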

# BAD — mixes channels with time
x = x.reshape(B*T, C, H, W)  # Corrupts spatial layout

# GOOD — make B,T adjacent first
x = x.transpose(0, 2, 1, 3, 4)  # [B,T,C,H,W]
x = x.reshape(B*T, C, H, W)     # Now correct

Patchify channel ordering

Bug (Wan2.2): Solid green video output from wrong patchify order.

When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering in the reshape must match the Conv3d weight layout. Conv3d expects [C, pt, ph, pw] (channels slowest), but a naive reshape produces [pt, ph, pw, C] (channels fastest):

# x: [B, F, H, W, C] (channels-last); Fp, Hp, Wp = F // pt, H // ph, W // pw

# BAD: channel scrambling
patches = x.reshape(B, Fp, Hp, Wp, pt, ph, pw, C)

# GOOD: match Conv3d weight layout
patches = x.reshape(B, Fp, pt, Hp, ph, Wp, pw, C)
patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6)  # [B, Fp, Hp, Wp, C, pt, ph, pw]

Verify numerically: the fixed version should match Conv3d output to ~1e-6.

mx.zeros / padding does not inherit dtype

Use dtype-aware mx.zeros for padding and concatenation to avoid promotion:

# BAD — default float32 padding promotes bfloat16 input
pad = mx.zeros((B, pad_len, C))  # float32!
x = mx.concatenate([pad, x], axis=1)  # x promoted to float32

# GOOD — match input dtype
pad = mx.zeros((B, pad_len, C), dtype=x.dtype)
x = mx.concatenate([pad, x], axis=1)  # stays bfloat16

Use mx.fast kernels

Replace manual implementations with fused MLX kernels where possible:

# Manual RMS norm → mx.fast.rms_norm
# Manual LayerNorm → mx.fast.layer_norm
# Manual attention → mx.fast.scaled_dot_product_attention

These are faster and handle precision internally.


4. Autoregressive Chunk Boundaries

For models that generate long videos by autoregressively extending chunks (Helios, CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts.

Don't add post-processing the reference doesn't have

Bug: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop.

The reference pipeline used no cross-fade at all. The first frame of each new chunk is intentionally a sharp reconstruction conditioned on history. Blending it with the previous chunk's tail (which has different content) creates blur.

Rule: Before adding smoothing/blending, verify the reference doesn't do it. Reference simplicity is usually correct.

First-frame artifacts are common

The first pixel frame of each non-first chunk is typically a distorted reconstruction of the conditioning frame. In many models, this is expected behavior:

  • Fix: Drop the first frame from each chunk
  • Verify frame math: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds
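
The frame arithmetic is easy to check in isolation. This is a small sketch with a hypothetical helper, `trim_chunks`, that follows the bullet above and drops the first frame of every chunk; whether the very first chunk should also be trimmed depends on the reference pipeline and is an assumption here.

```python
def trim_chunks(chunks, fps=16):
    """Drop the first frame of every chunk, then concatenate.

    chunks: list of per-chunk frame sequences (anything that supports slicing).
    Returns (frames, duration_seconds).
    """
    frames = []
    for chunk in chunks:
        frames.extend(chunk[1:])
    return frames, len(frames) / fps


# Three chunks of 33 raw frames each; integers stand in for decoded frames.
raw_chunks = [list(range(33)) for _ in range(3)]
frames, seconds = trim_chunks(raw_chunks, fps=16)
```

With 33 raw frames per chunk at 16 fps, every chunk contributes exactly 2 seconds after trimming, so the total duration stays a whole number of seconds.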

History conditioning errors compound

Small errors in how history is prepared, sliced, patchified, or position-encoded will compound across chunks. The error is invisible in chunk 1, small in chunk 2, and catastrophic by chunk 5.

Debug strategy: Generate with frozen history (same history for every chunk). If the artifact disappears, the bug is in history handling.


5. VAE Decoder Artifacts

Causal temporal convolutions cause boundary warmup

Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding each chunk independently, the first few frames lack temporal context (only zero padding), causing:

  • ~7% contrast drop in first frames of each chunk
  • Spatial brightness redistribution (face darkens, background brightens)

This is inherent to the architecture. The reference has the same effect but at lower magnitude.

Post-processing to fix VAE warmup

Two-stage correction applied to first N frames of each non-first chunk:

import cv2
import numpy as np

# ref_frame: previous chunk's last frame; cur_frame: frame being corrected (float32, HxWx3)
# ramp: scalar weight for this frame, 1.0 at the boundary falling to 0.0 over N frames

# Stage 1: Spatially-varying brightness correction
# Downsample reference (previous chunk's last frame) and current frame
ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
diff_small = ref_small - cur_small
diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
corrected = cur_frame + ramp * diff_full

# Stage 2: Per-channel contrast matching
for c in range(3):
    ref_std = np.std(ref_frame[:,:,c])
    cur_std = np.std(corrected[:,:,c])
    mean = corrected[:,:,c].mean()  # rescale around the channel mean so brightness is unchanged
    scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
    corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean

VAE overlap decode does NOT work

Attempted: Prepend previous chunk's last latent frames to give the decoder temporal context.

Result: Made things worse (22% contrast drop vs 7%). The causal convolutions see conflicting content from different chunks and create larger artifacts than zero-padding.

Lesson: Overlap only works when tiles contain the same content from the same denoising process (e.g., spatial tiling). It fails for temporal chunks with different content.

Per-chunk VAE decoding is correct

Decode each chunk's latents independently, not concatenated. Concatenating all chunks and decoding together lets boundary discontinuities propagate through temporal convolutions, creating worse artifacts.

First-frame quality: causal padding strategies

Multiple approaches were tried for the first-frame quality issue in Wan VAE:

| Approach | Result |
| --- | --- |
| Zero padding (default) | First ~4 frames degraded, but matches training |
| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) |
| Warmup frame prepend | Helps motion but warmup frame itself has artifacts |
| Mirror-reflect warmup | Best compromise: varied context without zeros, no intensity bias |

Lesson: Don't assume "replicate padding is better than zero padding." The model was trained with zero padding; changing it shifts the gain. Instead, prepend warmup frames and trim them after decoding.
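
The prepend-and-trim idea can be sketched on plain Python lists. Both helper names are hypothetical, and the default of 4 decoded frames per latent frame is an assumption about the VAE's temporal upsampling; adjust it to the model at hand. With real arrays the same indexing would be done along the time axis.

```python
def prepend_mirror_warmup(latent_frames, n_warmup):
    """Prepend n_warmup frames reflected around the first frame.

    [f0, f1, f2, f3] with n_warmup=2 becomes [f2, f1, f0, f1, f2, f3].
    """
    if not 0 < n_warmup < len(latent_frames):
        raise ValueError("n_warmup must be between 1 and len(latent_frames) - 1")
    warmup = latent_frames[n_warmup:0:-1]  # f_n, ..., f_1 (excludes f0)
    return warmup + latent_frames


def trim_warmup(decoded_frames, n_warmup, frames_per_latent=4):
    """Drop the decoded frames that came from the warmup latents.

    frames_per_latent=4 is an assumed temporal upsampling factor.
    """
    return decoded_frames[n_warmup * frames_per_latent:]


padded = prepend_mirror_warmup([0, 1, 2, 3], n_warmup=2)
```

Reflection gives the causal convolutions varied, plausible context instead of zeros or a repeated frame, which is the property the table above credits for avoiding the intensity bias.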

RMS_norm vs L2 normalization

Bug (Wan2.2): Garbled output from incorrect normalization.

A PyTorch class named RMS_norm actually uses F.normalize (L2 norm: x / ||x||_2), not RMS normalization (x / sqrt(mean(x²))). The difference is a factor of sqrt(C), causing values to explode through the decoder.

Lesson: Don't trust class names — read the actual implementation.
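
The sqrt(C) factor follows directly from the two definitions, and a few lines of standard-library Python confirm it. The function names are illustrative only and do not come from the reference code.

```python
import math


def l2_normalize(x):
    """x / ||x||_2, which is what F.normalize computes along one axis."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]


def rms_normalize(x):
    """x / sqrt(mean(x^2)), i.e. textbook RMS normalization."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]


x = [0.5, -1.0, 2.0, 4.0]  # C = 4 channels
ratios = [r / l for r, l in zip(rms_normalize(x), l2_normalize(x))]
# Every ratio equals sqrt(C), so swapping one norm for the other rescales activations.
```

For a layer with hundreds of channels the mismatch is more than an order of magnitude per normalization, which compounds quickly through a deep decoder.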

Temporal frame count: causal boundary effects

Bug (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input.

PyTorch reference processes frames one-by-one with caching, skipping temporal conv for the first chunk. All-at-once decoding produces extra frames from zero-padded causal context.

Fix: Use first_chunk=True flag to trim causal boundary frames, matching the reference's chunked behavior.

Chunked VAE encoding for I2V

Bug (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding.

Non-chunked encoding with causal zero-padding produces incorrect latents because temporal features don't propagate correctly without caching. The reference uses chunked encoding (1+4+4+... frames) with persistent temporal cache.

Fix: Implement chunked encoding with feat_cache propagation through CausalConv3d, ResidualBlock, and Resample layers.
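
The 1+4+4+... split itself is simple to get wrong by one frame, so it is worth isolating. This sketch computes only the frame ranges; `encode_chunk_ranges` is a hypothetical helper, and it assumes the input has 1 + 4k frames. The cache propagation through the layers is not shown.

```python
def encode_chunk_ranges(num_frames, stride=4):
    """Return [(start, stop), ...] frame ranges: one frame first, then groups of `stride`."""
    if num_frames < 1 or (num_frames - 1) % stride != 0:
        raise ValueError("expected num_frames = 1 + k * stride")
    ranges = [(0, 1)]
    for start in range(1, num_frames, stride):
        ranges.append((start, start + stride))
    return ranges


ranges = encode_chunk_ranges(9)
```

Each range would be encoded in order while the same temporal cache object is passed through every call, so later chunks see features from earlier ones instead of zero padding.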


6. Scheduler & Timestep Issues

Copy formulas exactly

Even small differences in scheduler formulas compound over many steps:

# Dynamic time shifting — reference uses specific formula
mu = 0.5 + shift * 0.5  # NOT shift * 0.6 or any other constant

# Euler step
x_next = x + (sigma_next - sigma) * flow  # order matters: next - current
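
A scalar toy case shows why the sign of the step matters. The sketch assumes the rectified-flow convention x_sigma = (1 - sigma) * x0 + sigma * noise with an ideal model that predicts flow = noise - x0, which is consistent with x0 = xt - sigma * flow used elsewhere in this guide. The function and its `swap_order` flag are illustrative only.

```python
def euler_denoise(noise, x0, sigmas, swap_order=False):
    """Integrate from sigma=1 to sigma=0 with Euler steps on a scalar sample."""
    x = noise
    flow = noise - x0  # what an ideal model would predict under this convention
    for sigma, sigma_next in zip(sigmas, sigmas[1:]):
        dt = (sigma - sigma_next) if swap_order else (sigma_next - sigma)
        x = x + dt * flow
    return x


sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
correct = euler_denoise(noise=0.8, x0=-0.3, sigmas=sigmas)                   # lands on x0
swapped = euler_denoise(noise=0.8, x0=-0.3, sigmas=sigmas, swap_order=True)  # moves away from x0
```

With the correct order the sample reaches x0; with the operands swapped it moves the same distance in the opposite direction, which in a real pipeline shows up as exploding values.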

Verify sigma schedules numerically

Print and compare sigma values at each step:

# Reference
sigmas_ref = [1.0, 0.99375, 0.9875, ...]

# Your implementation
sigmas = scheduler.get_sigmas(steps)
for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)):
    assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}"

Timestep embedding precision

Integer vs float timesteps matter. Some models expect timestep=999 (int), others expect timestep=0.999 (float). Wrong type can silently produce wrong embeddings with reasonable-looking but incorrect statistics.

Boundary conditions: ±inf at sigma endpoints

Bug (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers.

The lambda(sigma) function must return -inf at sigma=1.0 (pure noise) and +inf at sigma=0.0 (clean signal). Our implementation returned 0.0, causing massive x0 overscaling on the first denoising step.

PyTorch naturally computes torch.log(0) = -inf, and math.expm1(-inf) = -1.0 handles the formulas correctly. Reproduce this behavior explicitly:

def _lambda(self, sigma):
    if sigma >= 1.0:
        return float('-inf')
    if sigma <= 0.0:
        return float('inf')
    return -math.log(sigma / (1 - sigma))
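
The effect of the boundary value can be seen on the very first step. The sketch below uses a first-order data-prediction update in the DPM-Solver++ style, x_next = (sigma_next / sigma) * x - alpha_next * expm1(-h) * x0 with h = lambda(sigma_next) - lambda(sigma), and assumes alpha = 1 - sigma. The standalone function names are hypothetical, and `lam_broken` reproduces the 0.0 return value described above.

```python
import math


def lam(sigma):
    """Log-SNR style lambda with explicit infinities at the endpoints."""
    if sigma >= 1.0:
        return float("-inf")
    if sigma <= 0.0:
        return float("inf")
    return -math.log(sigma / (1.0 - sigma))


def lam_broken(sigma):
    """Same formula, but returning 0.0 at the endpoints (the bug)."""
    if sigma >= 1.0 or sigma <= 0.0:
        return 0.0
    return -math.log(sigma / (1.0 - sigma))


def first_order_coeffs(sigma, sigma_next, lam_fn):
    """Return (a, b) such that x_next = a * x + b * x0."""
    alpha_next = 1.0 - sigma_next
    h = lam_fn(sigma_next) - lam_fn(sigma)
    return sigma_next / sigma, -alpha_next * math.expm1(-h)


# First step of a schedule: sigma 1.0 -> 0.99
a_ok, b_ok = first_order_coeffs(1.0, 0.99, lam)            # b is about +0.01
a_bad, b_bad = first_order_coeffs(1.0, 0.99, lam_broken)   # b is about -0.98
```

With the infinities in place, h is +inf and expm1(-h) is exactly -1.0, so x0 enters with weight alpha_next. With 0.0 at the boundary the x0 weight is roughly 98 times larger and has the wrong sign, which matches the overscaling symptom.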

UniPC corrector coefficients

Bug (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights.

The UniPC corrector must compute rhos_c via linalg.solve for order ≥ 2. Hardcoded 0.5 was 7× too large for the history weight (actual: ~0.08), causing massive overweighting of history corrections.


7. Weight Conversion

Always verify statistically

After converting weights from PyTorch to MLX format:

for name in mlx_weights:
    pt = pytorch_weights[map_name(name)]
    mx_val = np.array(mlx_weights[name])
    pt_val = pt.numpy()
    cos_sim = np.dot(mx_val.flat, pt_val.flat) / (
        np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
    )
    if cos_sim < 0.9999:
        print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")

Conv3d → Linear reshaping

When converting 3D convolutions to linear layers (common for MLX which prefers linear ops), the flattening order must match:

# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
# The reshape order MUST match how the input is patchified

Sanitization functions

Write explicit weight sanitization that maps reference key names to your key names. Don't rely on automatic matching — key naming conventions differ between frameworks.

Module structure must mirror reference for direct loading

Bug (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch nn.Sequential structure. ResidualBlock needed None gaps at specific indices to match the original nn.Sequential(RMSNorm, SiLU, Conv3d, ...) indexing.

When possible, structure your modules to accept reference weights directly without sanitization. This eliminates an entire class of bugs.

Save VAE weights in float32

Even if the model uses bfloat16 for the transformer, save VAE weights in float32. Downcasting float32 weights to bfloat16 at conversion time discards mantissa bits, and a load-time upcast back to float32 cannot recover them.

Temporal downsample/upsample order

Bug (Wan2.2): temporal_downsample=[True, True, False] but reference uses [False, True, True]. Stage 0 created a time_conv with random weights (no matching file key), and Stage 2 missed its time_conv (weights silently dropped).

Always verify boolean flags for each stage by inspecting the actual weight file keys.

Silent weight drops are the worst bugs

When load_weights() with strict=False silently skips keys that don't match, you get a model with random weights for those layers. This produces output that looks "almost right" but is subtly wrong. Always log which keys were loaded vs skipped:

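from mlx.utils import tree_flatten

# Flatten the nested parameter tree into {"decoder.layer_0.weight": array, ...}
model_params = dict(tree_flatten(model.parameters()))
loaded_keys = set()
unexpected = []
for key, value in weights:  # weights: list of (name, array) pairs
    if key in model_params:
        loaded_keys.add(key)
    else:
        unexpected.append(key)  # present in the file, absent from the model
# Check for missing
expected = set(model_params.keys())
missing = expected - loaded_keys
if missing:
    print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...")
if unexpected:
    print(f"WARNING: {len(unexpected)} file keys unused: {unexpected[:5]}...")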

8. Text Conditioning Failures

Symptom: model predicts noise back to itself

If the model output correlates > 0.8 with its input noise, text conditioning is likely broken. The model has learned nothing from the prompt and is just returning its input.

Check embedding statistics

text_emb = text_encoder(prompt)
print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}")
# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights
# std > 10.0 → embeddings are exploding → wrong normalization

Verify with ablation

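# Generate with real text
output_text = denoise(latent, text_emb=real_embeddings)
# Generate with zeros
output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings))
# Compare: mean absolute change relative to the output magnitude
# (cast to float32 first; NumPy cannot convert bfloat16 arrays)
out_text = np.array(output_text.astype(mx.float32))
out_zero = np.array(output_zero.astype(mx.float32))
text_influence = np.mean(np.abs(out_text - out_zero)) / (np.mean(np.abs(out_text)) + 1e-8)
print(f"Text influence: {text_influence:.1%}")  # Should be > 0 (typically 30-60% of output)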

Text preprocessing must match exactly

Bug (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization.

The official Wan2.2 tokenizer applies ftfy.fix_text + html.unescape + whitespace normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C) tokenize differently from ASCII commas (U+002C), causing 27 different token IDs in the negative prompt. This made CFG's unconditional prediction wrong.

Fix: Apply the same text cleaning pipeline as the reference:

import ftfy
import html
import re

def clean_text(text):
    text = ftfy.fix_text(text)
    text = html.unescape(text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
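
When a prompt tokenizes differently from the reference, it helps to confirm whether fullwidth characters are actually present before suspecting the tokenizer. This is a small standard-library diagnostic; `fullwidth_chars` is a hypothetical helper and it only reports the fullwidth ASCII variants (U+FF01 to U+FF5E). It does not replace the ftfy-based cleaning above.

```python
def fullwidth_chars(text):
    """List (index, character, code point) for every fullwidth ASCII variant in text."""
    return [
        (i, ch, f"U+{ord(ch):04X}")
        for i, ch in enumerate(text)
        if 0xFF01 <= ord(ch) <= 0xFF5E
    ]


# A negative prompt containing a fullwidth comma (U+FF0C) instead of an ASCII one.
found = fullwidth_chars("blurry\uff0clow quality")
clean = fullwidth_chars("blurry, low quality")
```

Running this on both the raw and the cleaned negative prompt makes it obvious whether the cleaning step changed the characters that reach the tokenizer.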

T5 encoder precision

Bug (Wan2.2): Quality degradation from bfloat16 T5 attention.

T5 uses no scaling in attention (no 1/sqrt(d) factor), so attention logits can be very large. bfloat16 softmax loses significant precision across 24 encoder layers.

Fix: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per generation, so the performance impact is negligible.

Dual-model text embeddings

Bug (Wan2.2 I2V-14B): Low/high noise models have different text_embedding weights (~42% relative difference). Using one model's embeddings for both caused incorrect text conditioning for the high-noise model that handles critical early denoising steps.

Fix: Compute separate text embeddings for each model in dual-model setups.


9. Position Encodings (RoPE)

Multi-scale consistency

In pyramid/multi-resolution models, RoPE must be computed consistently across scales. If the model operates at 1/4 resolution in an early stage, the position grid must reflect the actual spatial dimensions, not the final target dimensions.

History vs current chunk

When conditioning on history from a previous chunk, the position encoding for history frames must match what the model saw during training. Mismatches between history and current-chunk position encodings can cause subtle spatial distortions that compound across chunks.

Factorized RoPE

3D video models often use factorized RoPE (separate temporal, height, width frequencies). Verify each axis independently:

# Compare temporal frequencies
assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5)
# Compare spatial frequencies
assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5)
assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5)

Per-axis frequency construction

Bug (Wan2.2): Grey/artifact-filled output from wrong frequency distribution.

The reference uses three separate rope_params() calls with different dimension normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency range. Consolidating into a single rope_params(head_dim) call and splitting gave height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both).

Fix (and subsequent revert): This bug was introduced as a "fix" for a previous RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects. Always verify with actual generation, not just numerical comparison of frequencies.

Lesson: Read the reference's frequency construction very carefully. Don't "simplify" three separate calls into one unless you verify the frequency distribution matches exactly.
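
The frequency mismatch can be reproduced with plain arithmetic. The sketch assumes the usual RoPE inverse-frequency formula theta^(-i/dim) with theta = 10000 and a head dimension of 128; both values are assumptions for illustration, and `rope_freqs` is a stand-in for the reference's rope_params(), not its actual code.

```python
def rope_freqs(dim, theta=10000.0):
    """Inverse frequencies theta^(-i/dim) for i = 0, 2, 4, ..., dim - 2."""
    return [theta ** (-i / dim) for i in range(0, dim, 2)]


HEAD_DIM = 128            # assumed head dimension
AXIS_DIMS = (44, 42, 42)  # temporal, height, width split from the text above

# Reference style: one call per axis, each normalized by its own dimension.
per_axis = [rope_freqs(d) for d in AXIS_DIMS]

# "Simplified" style: one call over the full head, then split by position.
single = rope_freqs(HEAD_DIM)
n_t, n_h, n_w = (d // 2 for d in AXIS_DIMS)
split = [single[:n_t], single[n_t:n_t + n_h], single[n_t + n_h:]]

per_axis_starts = [axis[0] for axis in per_axis]  # every axis starts at 1.0
split_starts = [axis[0] for axis in split]        # height and width start far below 1.0
print(per_axis_starts, split_starts)
```

Under these assumptions the split version starts the height axis near 0.042 and the width axis near 0.002, matching the values reported for the bug, so the spatial axes lose their highest-frequency components.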


10. Multi-Stage / Pyramid Pipelines

Each stage is a potential failure point

Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the number of things that can go wrong:

  • Downsampling method (bilinear vs area) must match reference
  • Energy compensation factors (e.g., ×2 after bilinear downsample) must be present
  • Alpha/beta noise mixing coefficients are stage-dependent
  • Frame indices and history resolution change per stage

Test single-stage first

If the model works at full resolution for a single stage but fails in the pyramid, the bug is in stage orchestration — typically in how latents are passed between stages or how position encodings adapt to different resolutions.

Integration bugs are the hardest

We verified every Helios component matched the reference individually, but the pyramid still produced uniform color. The bug was in dtype handling during stage transitions. Integration bugs only appear when components interact.


11. Common Symptoms → Root Causes

| Symptom | Likely Root Causes |
| --- | --- |
| Pure noise output | Wrong sigma schedule, broken text conditioning, incorrect weight mapping |
| Uniform color | Model predicting noise back; text embeddings collapsed; wrong timestep format |
| Progressive zoom/shrink | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks |
| Brightness jumps at boundaries | VAE causal warmup; cross-fade blending misaligned content |
| Color drift across chunks | Dtype in scheduler step; history normalization missing |
| Blur at boundaries | Cross-fade enabled; latent blending; wrong VAE decode order |
| Grid/checker patterns | Patchify channel ordering bug; latent blend artifacts; reshape axis error |
| Green/magenta tint | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix |
| Mean drift across steps | bfloat16 accumulation; wrong scheduler formula; missing energy compensation |
| Garbled/scrambled output | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm |
| Greenish-yellow constant | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling |
| ~2x slower than expected | Float32 promotion cascade from single mistyped intermediate |
| Extra output frames | Causal padding producing extra temporal frames; missing first_chunk trim |
| Grey/artifact output | RoPE frequency construction wrong (per-axis vs single-call) |
| Patchy-blurry with CFG | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) |
| I2V temporal mismatch | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache |

12. Verification Checklist

Use this checklist when porting a new diffusion video model:

Model

  • Weight conversion: all keys mapped, cosine similarity > 0.9999
  • No silent weight drops (log loaded vs expected keys)
  • Single forward pass matches reference (cos_sim > 0.999)
  • Residual connections use float32 accumulation
  • Attention computation matches reference precision
  • Modulation/gate vectors in float32 (if reference uses autocast)
  • No underscore-prefixed module attributes (MLX ignores them)

Scheduler

  • Sigma values match reference at every step (diff < 1e-6)
  • Timestep format correct (int vs float, scale factor)
  • Dynamic shifting formula copied exactly
  • Step function returns correct dtype (float32)
  • Boundary conditions: lambda = -inf at sigma=1, lambda = +inf at sigma=0
  • Higher-order coefficients computed (not hardcoded) for UniPC/DPM++

Text Encoder

  • Embedding statistics reasonable (0.01 < std < 10)
  • Text influence > 0 (ablation test)
  • Tokenization matches (special tokens, padding, max length)
  • Text preprocessing matches (ftfy, html unescape, whitespace normalization)
  • T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
  • Separate embeddings for dual-model setups (if applicable)

VAE

  • Denormalization constants match training pipeline
  • Per-chunk decoding (not concatenated)
  • Temporal frame count correct (account for causal padding)
  • Weight keys mapped correctly (encoder vs decoder)
  • Weights stored/loaded in float32 (not bfloat16)
  • Temporal downsample/upsample order matches reference
  • RMS_norm vs L2_norm: check actual implementation, not class name
  • Chunked encoding for I2V (if applicable)
  • Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)

Pipeline Orchestration

  • Position encodings consistent across stages/chunks
  • History slicing and conditioning correct
  • Noise generation matches (distribution, correlation structure)
  • Multi-chunk output visually consistent (no progressive degradation)
  • Dimension auto-alignment (divisible by patch_size × vae_stride)
  • Dtype-aware padding (mx.zeros with explicit dtype)

Output

  • Frame count matches expected (account for warmup/trim)
  • FPS correct
  • Color range [0, 255] uint8 for video
  • No first-frame duplication artifacts
  • Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)

Performance

  • No float32 promotion cascades (check with profiler)
  • Using mx.fast kernels (rms_norm, layer_norm, sdpa)
  • Time embedding computed once per sample (not per position)
  • Memory cleanup (delete temporaries before mx.eval)

13. Diagnostic Tools

General video diagnostics (scripts/video/)

| Script | Purpose |
| --- | --- |
| compare_videos.py | PSNR, SSIM, temporal coherence, color fidelity between two videos |
| video_quality.py | Sharpness, stability, defect detection, chunk boundary analysis |

# Quick quality check
python scripts/video/video_quality.py output.mp4 --chunk-size 32

# Compare against reference
python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4

Model-specific diagnostics (scripts/helios/)

| Script | Purpose |
| --- | --- |
| analyze_boundaries.py | Detailed boundary quality metrics for Helios |
| run_reference.py | Run PyTorch reference on MPS |
| compare_pipelines.py | Compare scheduler/pipeline mechanics |
| compare_models.py | Cross-framework model output comparison |

Inline debugging pattern

Add temporary debug output to the diffusion loop:

for i, sigma in enumerate(sigmas):
    flow = model(latent, sigma, text_emb)
    latent = scheduler.step(latent, flow, sigma, sigma_next)

    # Debug: track statistics
    print(f"[step {i}] sigma={sigma:.4f} "
          f"latent: mean={latent.mean():.6f} std={latent.std():.6f} "
          f"flow: mean={flow.mean():.6f} std={flow.std():.6f}")

    # Debug: save for cross-framework comparison
    if os.environ.get("DEBUG"):
        mx.save(f"/tmp/debug_step_{i}.npz", {
            "latent": latent, "flow": flow, "sigma": mx.array(sigma)
        })

Key Takeaways

  1. Precision is the #1 bug source — bfloat16 residuals, scheduler math, type promotion, modulation vectors. Copy the reference's .float() and autocast zones.

  2. Don't add what the reference doesn't have — cross-fade, overlap decode, temporal blending. If the reference works without it, you probably have a bug elsewhere.

  3. Silent failures are the hardest bugs — underscore-prefixed weights, strict=False weight loading, wrong normalization class names. Always verify weight load counts and output statistics.

  4. Component isolation → integration testing — verify each part matches, then debug their interaction.

  5. Statistical comparison beats visual inspection — mean drift, contrast ratios, and cosine similarity catch bugs before they're visible.

  6. Autoregressive errors compound — a 1% error per chunk becomes 10% by chunk 10. Fix precision first, add corrections second.

  7. MLX has unique pitfalls — underscore attribute names, reshape axis ordering, dtype-unaware padding, and float32 promotion cascades. Know your framework.

  8. Text preprocessing matters — Unicode normalization, fullwidth chars, HTML entities. A single mismatched comma can break CFG guidance.

  9. VAE is deceptively complex — causal padding, temporal frame counts, chunked vs batch processing, norm implementations. Budget significant debugging time for VAE.