diff --git a/docs/PORTING-GUIDE.md b/docs/PORTING-GUIDE.md new file mode 100644 index 0000000..5c44f94 --- /dev/null +++ b/docs/PORTING-GUIDE.md @@ -0,0 +1,911 @@ +# Porting Diffusion Video Models to MLX: Lessons Learned + +A practical guide distilled from porting Wan2.1/2.2 (1.3B–14B) and Helios 14B DiT +video generation models from PyTorch to MLX on Apple Silicon. These lessons apply +broadly to any diffusion-based video (or image) model port. + +--- + +## Table of Contents + +1. [Debugging Methodology](#1-debugging-methodology) +2. [Precision & Dtype Pitfalls](#2-precision--dtype-pitfalls) +3. [MLX-Specific Gotchas](#3-mlx-specific-gotchas) +4. [Autoregressive Chunk Boundaries](#4-autoregressive-chunk-boundaries) +5. [VAE Decoder Artifacts](#5-vae-decoder-artifacts) +6. [Scheduler & Timestep Issues](#6-scheduler--timestep-issues) +7. [Weight Conversion](#7-weight-conversion) +8. [Text Conditioning Failures](#8-text-conditioning-failures) +9. [Position Encodings (RoPE)](#9-position-encodings-rope) +10. [Multi-Stage / Pyramid Pipelines](#10-multi-stage--pyramid-pipelines) +11. [Common Symptoms → Root Causes](#11-common-symptoms--root-causes) +12. [Verification Checklist](#12-verification-checklist) +13. [Diagnostic Tools](#13-diagnostic-tools) + +--- + +## 1. Debugging Methodology + +### Component isolation first + +Never debug the full pipeline. Test each component in isolation: + +1. **Text encoder** — Does it produce embeddings with reasonable statistics? (std > 0.01) +2. **Scheduler** — Do sigma/timestep values match the reference exactly? +3. **Transformer** — Does a single forward pass match the reference? (cosine similarity > 0.999) +4. **VAE decoder** — Feed reference latents into your VAE. Does the output look correct? + +If every component matches individually but the pipeline fails, the bug is in +**orchestration** — how components are wired together. 
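The per-component thresholds above can be wrapped in a small framework-free helper. This is a sketch only: `cosine_similarity` and `check_component` are hypothetical names, and it assumes each tensor has already been converted to float32 and flattened to a plain list of floats.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat an all-zero tensor as a mismatch
    return dot / (norm_a * norm_b)


def check_component(name, reference, candidate, min_cos_sim=0.999):
    """Print a one-line report and return True if the outputs match closely."""
    cos_sim = cosine_similarity(reference, candidate)
    max_abs_err = max(abs(x - y) for x, y in zip(reference, candidate))
    ok = cos_sim > min_cos_sim
    print(f"{name}: cos_sim={cos_sim:.6f} max_abs_err={max_abs_err:.6f} "
          f"{'OK' if ok else 'MISMATCH'}")
    return ok
```

Running the same helper on the text encoder, transformer, and VAE outputs gives comparable one-line reports, which makes it easier to see which component diverges first.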
+
+### Statistical fingerprinting
+
+Track per-step statistics through the diffusion loop:
+
+```python
+# After each denoising step
+print(f"step {i}: mean={latent.mean():.6f} std={latent.std():.6f} "
+      f"min={latent.min():.4f} max={latent.max():.4f}")
+```
+
+**What to look for:**
+- **Progressive mean drift** (e.g., -0.002 → -0.040 → -0.123) signals accumulating errors
+- **Collapsing std** (std dropping toward 0) signals broken conditioning or wrong noise schedule
+- **Exploding values** signal wrong sigma scaling or scheduler formula
+
+### Cross-framework numerical comparison
+
+The most powerful debugging tool: save intermediate tensors from your MLX pipeline,
+feed them to the PyTorch reference, and compare the outputs.
+
+```python
+# In MLX pipeline, save inputs before transformer call.
+# mx.savez takes arrays as keyword arguments (mx.save writes a single array).
+# Cast bfloat16 tensors to float32 first, since NumPy has no bfloat16.
+mx.savez("debug_inputs.npz", latent=latent, timestep=t, text_emb=text_emb)
+
+# In PyTorch script, load and compare
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+inputs = np.load("debug_inputs.npz")
+mlx_out = np.load("debug_output.npz")["flow"]
+pt_out = reference_model(torch.from_numpy(inputs["latent"]), ...)
+cos_sim = F.cosine_similarity(pt_out.flatten(), torch.from_numpy(mlx_out).flatten(), dim=0)
+# cos_sim > 0.999 = model is correct; bug is elsewhere
+# cos_sim < 0.99 = model has a bug; compare per-layer
+```
+
+### Ablation testing
+
+When a pipeline has multiple "fixes" or features, disable them one at a time:
+
+- **Frozen history**: Fix history to the same value for all chunks → proves whether
+  history propagation is the source of drift/zoom
+- **Single chunk**: Generate only 1 chunk → isolates per-chunk quality from
+  multi-chunk interaction bugs
+- **Disable post-processing**: Remove cross-fade, blending, corrections → reveals
+  what the raw model output looks like
+
+### Use reference on same hardware
+
+Run the PyTorch reference on the same device (MPS for Apple Silicon). CUDA and MPS
+produce different numerical results due to different float handling.
Comparing your +MLX output against a CUDA reference adds noise to the comparison. + +```python +# MPS may not support float64 — patch the reference: +original_linspace = torch.linspace +def patched_linspace(*args, **kwargs): + kwargs.pop("dtype", None) + return original_linspace(*args, dtype=torch.float32, **kwargs) +torch.linspace = patched_linspace +``` + +--- + +## 2. Precision & Dtype Pitfalls + +### The #1 source of subtle bugs + +Precision issues caused the most insidious bugs in our port. They don't cause +crashes — they cause progressive quality degradation that's hard to attribute. + +### Residual connections MUST be float32 + +**Bug**: Progressive zoom/shrinking across autoregressive chunks. + +**Root cause**: Residual additions (`x = x + attn_out`) in bfloat16. With 7-bit +mantissa, high-frequency spatial detail is systematically truncated. Over 144 +residual ops × 6+ model calls per chunk, detail is progressively smoothed away. + +**Fix**: Promote to float32 for the addition: +```python +# BAD — bfloat16 accumulation +x = x + attn_out + +# GOOD — match reference's .float() pattern +x = (x.astype(mx.float32) + attn_out).astype(weight_dtype) +``` + +**Rule**: If the reference uses `.float()` anywhere, copy that pattern exactly. It's +there for a reason, even if a quick test seems to work without it. + +### Scheduler computations need high precision + +Diffusion schedulers involve: +- `x0 = xt - sigma * flow` — catastrophic cancellation near sigma ≈ 1 +- `log(sigma)` and `exp()` — sensitive to small precision differences + +Some references use float64 for these computations. MLX GPU doesn't support float64, +so use float32 and accept small numerical differences, but **never** use bfloat16 +for scheduler math. + +### Dtype propagation is invisible + +Track dtype through your pipeline. 
A single bfloat16 intermediate can silently
+downcast everything downstream:
+
+```python
+# This looks harmless, but if noise and model_output are both bfloat16:
+result = noise - sigma * model_output  # result is bfloat16!
+
+# Fix: explicit cast
+result = (noise.astype(mx.float32) - sigma * model_output.astype(mx.float32))
+```
+
+### Type promotion rules differ across frameworks
+
+- PyTorch: bfloat16 + float32 → float32
+- MLX: bfloat16 + float32 → float32 (same, but verify)
+- NumPy: no bfloat16 support
+
+Always check what your framework does and match the reference's implicit promotions.
+
+### Float32 for VAE decoding
+
+**Bug** (Wan2.2): VAE decode in bfloat16 produced visibly worse quality than reference.
+
+Official Wan2.2 runs VAE decode in `torch.float` (float32), but our converted weights
+were bfloat16. The VAE has many sequential layers where precision loss compounds.
+
+**Fix**: Upcast VAE weights to float32 at load time. The VAE runs once per generation,
+so the performance impact is negligible compared to the transformer.
+
+### Modulation/gate vectors need float32
+
+**Bug** (Wan2.2): Quality degradation from bfloat16 modulation across 30 blocks × 50 steps.
+
+The official Wan2.2 explicitly uses `torch.amp.autocast('cuda', dtype=torch.float32)`
+for time embeddings, modulation parameters, norm outputs before modulation, and gate ops.
+
+**Fix**: Keep modulation in float32, cast to working dtype only when applying to the
+hidden state:
+```python
+# Modulation computed in float32
+e0 = self.modulation(time_emb)  # float32
+scale, shift, gate = e0.split(3, axis=-1)
+
+# Cast to the working dtype only when applying scale/shift to the hidden state
+x = (x * (1 + scale.astype(x.dtype)) + shift.astype(x.dtype))
+```
+
+### Map PyTorch autocast zones precisely
+
+PyTorch models use nested `torch.amp.autocast` scopes to switch precision.
Map these +exactly: +- **Outer scope** (`bfloat16`): attention QKV projections, FFN matmuls +- **Inner scope** (`float32`): modulation, gates, norms, RoPE +- **Residual stream**: float32 (the "backbone" between blocks) + +```python +# Wan2.2 dtype flow (matches official): +# Modulation/gates: float32 (explicit) +# QKV/FFN linear projections: bfloat16 (weight dtype) +# RoPE: float32 (official uses float64, MLX lacks float64) +# Attention Q/K: cast back to bfloat16 after RoPE +# Residual stream: float32 +``` + +### Float32 promotion cascades kill performance + +**Bug** (Wan2.2): ~2x slowdown from accidental float32 promotion. + +A single float32 tensor (e.g., time embedding) flowing into bfloat16 operations +promotes the entire computation graph to float32. In Wan2.2: +- Time embedding MLP output (float32) fed into transformer → all layers float32 +- RoPE frequencies (float32) applied to Q/K → all attention float32 + +**Fix**: Cast intermediate results to model dtype at promotion boundaries: +```python +# After time embedding MLP (float32), cast before feeding to transformer +time_emb = time_mlp(t).astype(model_dtype) + +# After RoPE (float32), cast Q/K back to attention dtype +q = rope_apply(q, freqs).astype(v.dtype) +``` + +--- + +## 3. MLX-Specific Gotchas + +### Underscore-prefixed attributes are invisible + +**Bug** (Wan2.2): 87 of 110 VAE weights silently dropped during loading. + +MLX's `nn.Module.parameters()` and `nn.Module.load_weights()` **skip** attributes +whose names start with underscore. If you name a layer `self._layer_0`, its weights +will never be loaded or saved. + +```python +# BAD — weights silently ignored +self._layer_0 = nn.Linear(...) # nn.Module skips _prefixed attrs + +# GOOD +self.layer_0 = nn.Linear(...) +``` + +This is especially insidious because there's no error — the model loads, runs, and +produces output. The output is just garbage because most weights are random. 
+
+### nn.Sequential indexing vs named children
+
+PyTorch's `nn.Sequential` uses integer indices (`sequential.0.weight`), while MLX's
+module hierarchy uses named attributes. When mirroring a PyTorch module structure,
+you need explicit key sanitization:
+
+```python
+import re
+
+def sanitize_key(key):
+    # PyTorch: "decoder.middle.0.residual.1.weight"
+    # MLX:     "decoder.middle.layer_0.residual.layer_1.weight"
+    key = re.sub(r'\.(\d+)', lambda m: f'.layer_{m.group(1)}', key)
+    return key
+```
+
+### Reshape axis ordering: transpose before merging axes
+
+**Bug** (Wan2.2): Green checkerboard pattern from VAE attention.
+
+`[B,C,T,H,W]` cannot be directly reshaped to `[BT,C,H,W]` because in memory C
+comes before T. A plain `reshape` only reinterprets the row-major element order, in
+PyTorch as much as in MLX; reference code typically permutes the axes first (for
+example with `permute` or an einops `rearrange`), which is easy to miss when porting.
+In MLX, transpose explicitly before reshaping:
+
+```python
+# BAD — mixes channels with time
+x = x.reshape(B*T, C, H, W)  # Corrupts spatial layout
+
+# GOOD — make B,T adjacent first
+x = x.transpose(0, 2, 1, 3, 4)  # [B,T,C,H,W]
+x = x.reshape(B*T, C, H, W)  # Now correct
+```
+
+### Patchify channel ordering
+
+**Bug** (Wan2.2): Solid green video output from wrong patchify order.
+
+When converting a Conv3d patchify to a manual reshape+linear, the dimension ordering
+in the reshape must match the Conv3d weight layout. Conv3d expects `[C, pt, ph, pw]`
+(channels slowest), but a naive reshape produces `[pt, ph, pw, C]` (channels fastest):
+
+```python
+# Fp, Hp, Wp are the patch-grid sizes: F // pt, H // ph, W // pw
+
+# BAD — channel scrambling
+patches = x.reshape(B, Fp, Hp, Wp, pt, ph, pw, C)
+
+# GOOD — match Conv3d weight layout
+patches = x.reshape(B, Fp, pt, Hp, ph, Wp, pw, C)
+patches = patches.transpose(0, 1, 3, 5, 7, 2, 4, 6)  # [B, Fp, Hp, Wp, C, pt, ph, pw]
+```
+
+Verify numerically: the fixed version should match Conv3d output to ~1e-6.
+
+### mx.zeros / padding inherits dtype
+
+Use dtype-aware `mx.zeros` for padding and concatenation to avoid promotion:
+
+```python
+# BAD — default float32 padding promotes bfloat16 input
+pad = mx.zeros((B, pad_len, C))  # float32!
+x = mx.concatenate([pad, x], axis=1) # x promoted to float32 + +# GOOD — match input dtype +pad = mx.zeros((B, pad_len, C), dtype=x.dtype) +x = mx.concatenate([pad, x], axis=1) # stays bfloat16 +``` + +### Use mx.fast kernels + +Replace manual implementations with fused MLX kernels where possible: + +```python +# Manual RMS norm → mx.fast.rms_norm +# Manual LayerNorm → mx.fast.layer_norm +# Manual attention → mx.fast.scaled_dot_product_attention +``` + +These are faster and handle precision internally. + +--- + +## 4. Autoregressive Chunk Boundaries + +For models that generate long videos by autoregressively extending chunks (Helios, +CogVideoX, etc.), chunk boundaries are the primary source of visual artifacts. + +### Don't add post-processing the reference doesn't have + +**Bug**: Added pixel cross-fade to smooth boundaries → caused 40% sharpness drop. + +The reference pipeline used **no cross-fade at all**. The first frame of each new +chunk is intentionally a sharp reconstruction conditioned on history. Blending it with +the previous chunk's tail (which has different content) creates blur. + +**Rule**: Before adding smoothing/blending, verify the reference doesn't do it. +Reference simplicity is usually correct. + +### First-frame artifacts are common + +The first pixel frame of each non-first chunk is typically a distorted reconstruction +of the conditioning frame. In many models, this is expected behavior: + +- **Fix**: Drop the first frame from each chunk +- **Verify frame math**: If 33 raw frames at 16fps → drop 1 → 32 frames = exactly 2 seconds + +### History conditioning errors compound + +Small errors in how history is prepared, sliced, patchified, or position-encoded +will compound across chunks. The error is invisible in chunk 1, small in chunk 2, +and catastrophic by chunk 5. + +**Debug strategy**: Generate with frozen history (same history for every chunk). +If the artifact disappears, the bug is in history handling. + +--- + +## 5. 
VAE Decoder Artifacts
+
+### Causal temporal convolutions cause boundary warmup
+
+Video VAEs (WanVAE, CogVideoX-VAE) use causal temporal convolutions. When decoding
+each chunk independently, the first few frames lack temporal context (only zero
+padding), causing:
+
+- **~7% contrast drop** in first frames of each chunk
+- **Spatial brightness redistribution** (face darkens, background brightens)
+
+This is inherent to the architecture. The reference has the same effect but at
+lower magnitude.
+
+### Post-processing to fix VAE warmup
+
+Two-stage correction applied to first N frames of each non-first chunk:
+
+```python
+# ref_frame / cur_frame: float32 HxWx3 arrays; ramp: 1.0 → 0.0 over N frames
+
+# Stage 1: Spatially-varying brightness correction
+# Downsample reference (previous chunk's last frame) and current frame
+ref_small = cv2.resize(ref_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
+cur_small = cv2.resize(cur_frame, (w//16, h//16), interpolation=cv2.INTER_AREA)
+diff_small = ref_small - cur_small
+diff_full = cv2.resize(diff_small, (w, h), interpolation=cv2.INTER_LINEAR)
+corrected = cur_frame + ramp * diff_full
+
+# Stage 2: Per-channel contrast matching
+for c in range(3):
+    ref_std = np.std(ref_frame[:,:,c])
+    cur_std = np.std(corrected[:,:,c])
+    mean = np.mean(corrected[:,:,c])  # scale around the channel mean
+    scale = 1.0 + ramp * (ref_std / (cur_std + 1e-6) - 1.0)
+    corrected[:,:,c] = (corrected[:,:,c] - mean) * scale + mean
+```
+
+### VAE overlap decode does NOT work
+
+**Attempted**: Prepend previous chunk's last latent frames to give the decoder
+temporal context.
+
+**Result**: Made things **worse** (22% contrast drop vs 7%). The causal convolutions
+see conflicting content from different chunks and create larger artifacts than
+zero-padding.
+
+**Lesson**: Overlap only works when tiles contain the same content from the same
+denoising process (e.g., spatial tiling). It fails for temporal chunks with
+different content.
+
+### Per-chunk VAE decoding is correct
+
+Decode each chunk's latents independently, not concatenated.
Concatenating all chunks +and decoding together lets boundary discontinuities propagate through temporal +convolutions, creating worse artifacts. + +### First-frame quality: causal padding strategies + +Multiple approaches were tried for the first-frame quality issue in Wan VAE: + +| Approach | Result | +|----------|--------| +| Zero padding (default) | First ~4 frames degraded, but matches training | +| Replicate padding | Fixes artifacts but causes color intensity bias (conv applies all kernel weights to same value) | +| Warmup frame prepend | Helps motion but warmup frame itself has artifacts | +| Mirror-reflect warmup | Best compromise — varied context without zeros, no intensity bias | + +**Lesson**: Don't assume "replicate padding is better than zero padding." The model +was trained with zero padding; changing it shifts the gain. Instead, prepend warmup +frames and trim them after decoding. + +### RMS_norm vs L2 normalization + +**Bug** (Wan2.2): Garbled output from incorrect normalization. + +A PyTorch class named `RMS_norm` actually uses `F.normalize` (L2 norm: `x / ||x||_2`), +not RMS normalization (`x / sqrt(mean(x²))`). The difference is a factor of `sqrt(C)`, +causing values to explode through the decoder. + +**Lesson**: Don't trust class names — read the actual implementation. + +### Temporal frame count: causal boundary effects + +**Bug** (Wan2.2): VAE produced 12 frames instead of 9 for a 9-frame input. + +PyTorch reference processes frames one-by-one with caching, skipping temporal conv for +the first chunk. All-at-once decoding produces extra frames from zero-padded causal +context. + +**Fix**: Use `first_chunk=True` flag to trim causal boundary frames, matching the +reference's chunked behavior. + +### Chunked VAE encoding for I2V + +**Bug** (Wan2.2 I2V-14B): Incorrect latents from non-chunked encoding. + +Non-chunked encoding with causal zero-padding produces incorrect latents because +temporal features don't propagate correctly without caching. 
The reference uses chunked +encoding (1+4+4+... frames) with persistent temporal cache. + +**Fix**: Implement chunked encoding with `feat_cache` propagation through CausalConv3d, +ResidualBlock, and Resample layers. + +--- + +## 6. Scheduler & Timestep Issues + +### Copy formulas exactly + +Even small differences in scheduler formulas compound over many steps: + +```python +# Dynamic time shifting — reference uses specific formula +mu = 0.5 + shift * 0.5 # NOT shift * 0.6 or any other constant + +# Euler step +x_next = x + (sigma_next - sigma) * flow # order matters: next - current +``` + +### Verify sigma schedules numerically + +Print and compare sigma values at each step: + +```python +# Reference +sigmas_ref = [1.0, 0.99375, 0.9875, ...] + +# Your implementation +sigmas = scheduler.get_sigmas(steps) +for i, (r, m) in enumerate(zip(sigmas_ref, sigmas)): + assert abs(r - m) < 1e-6, f"Step {i}: ref={r}, mlx={m}" +``` + +### Timestep embedding precision + +Integer vs float timesteps matter. Some models expect `timestep=999` (int), others +expect `timestep=0.999` (float). Wrong type can silently produce wrong embeddings +with reasonable-looking but incorrect statistics. + +### Boundary conditions: ±inf at sigma endpoints + +**Bug** (Wan2.2): Greenish/yellow constant output from DPM++/UniPC schedulers. + +The `lambda(sigma)` function must return `-inf` at `sigma=1.0` (pure noise) and `+inf` +at `sigma=0.0` (clean signal). Our implementation returned `0.0`, causing massive x0 +overscaling on the first denoising step. + +PyTorch naturally computes `torch.log(0) = -inf`, and `math.expm1(-inf) = -1.0` +handles the formulas correctly. Reproduce this behavior explicitly: + +```python +def _lambda(self, sigma): + if sigma >= 1.0: + return float('-inf') + if sigma <= 0.0: + return float('inf') + return -math.log(sigma / (1 - sigma)) +``` + +### UniPC corrector coefficients + +**Bug** (Wan2.2): Accumulated artifacts across 47+ steps from wrong polynomial weights. 
+
+The UniPC corrector must compute `rhos_c` via `linalg.solve` for order ≥ 2. A hardcoded
+`0.5` was 7× too large for the history weight (actual: ~0.08), causing massive
+overweighting of history corrections.
+
+---
+
+## 7. Weight Conversion
+
+### Always verify statistically
+
+After converting weights from PyTorch to MLX format:
+
+```python
+for name in mlx_weights:
+    pt = pytorch_weights[map_name(name)]
+    # Compare in float32: neither NumPy nor Tensor.numpy() handles bfloat16.
+    # Assumes both tensors share the same layout (no transposed axes).
+    mx_val = np.array(mlx_weights[name].astype(mx.float32)).ravel()
+    pt_val = pt.float().numpy().ravel()
+    cos_sim = np.dot(mx_val, pt_val) / (
+        np.linalg.norm(mx_val) * np.linalg.norm(pt_val) + 1e-10
+    )
+    if cos_sim < 0.9999:
+        print(f"MISMATCH: {name} cos_sim={cos_sim:.6f}")
+```
+
+### Conv3d → Linear reshaping
+
+When converting 3D convolutions to linear layers (common for MLX which prefers
+linear ops), the flattening order must match:
+
+```python
+# PyTorch Conv3d weight: (out_ch, in_ch, kT, kH, kW)
+# Flatten to Linear: (out_ch, in_ch * kT * kH * kW)
+# The reshape order MUST match how the input is patchified
+```
+
+### Sanitization functions
+
+Write explicit weight sanitization that maps reference key names to your key names.
+Don't rely on automatic matching — key naming conventions differ between frameworks.
+
+### Module structure must mirror reference for direct loading
+
+**Bug** (Wan2.2): Rewrote entire VAE module hierarchy to match PyTorch `nn.Sequential`
+structure. ResidualBlock needed `None` gaps at specific indices to match the original
+`nn.Sequential(RMSNorm, SiLU, Conv3d, ...)` indexing.
+
+When possible, structure your modules to accept reference weights directly without
+sanitization. This eliminates an entire class of bugs.
+
+### Save VAE weights in float32
+
+Even if the model uses bfloat16 for the transformer, save VAE weights in float32.
+A float32 → bfloat16 → float32 roundtrip loses precision that a load-time upcast
+cannot recover.
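To see why an upcast at load time cannot restore what a bfloat16 save discarded, the rounding can be emulated with the standard library alone. This is a sketch under stated assumptions: `bfloat16_roundtrip` and `float32` are hypothetical helpers that mimic round-to-nearest-even truncation of a float32 encoding to its upper 16 bits. They are not the conversion routines used by MLX or PyTorch.

```python
import struct


def float32(value):
    """Round a Python float (float64) to float32 precision."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def bfloat16_roundtrip(value):
    """Round a float to bfloat16 precision and return it as a Python float.

    Emulates round-to-nearest-even on the upper 16 bits of the float32
    encoding. NaN and overflow to infinity are not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


original = float32(0.123456789)
restored = bfloat16_roundtrip(original)
rel_err = abs(restored - original) / abs(original)
print(f"float32={original!r} after roundtrip={restored!r} rel_err={rel_err:.2e}")

# A small update next to 1.0 vanishes entirely at bfloat16 precision.
print(bfloat16_roundtrip(1.001))
```

The restored value is a valid float32, but the low 16 bits of the original are gone for good, which is why the conversion script should keep the VAE weights in float32.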
+ +### Temporal downsample/upsample order + +**Bug** (Wan2.2): `temporal_downsample=[True, True, False]` but reference uses +`[False, True, True]`. Stage 0 created a `time_conv` with random weights (no matching +file key), and Stage 2 missed its `time_conv` (weights silently dropped). + +Always verify boolean flags for each stage by inspecting the actual weight file keys. + +### Silent weight drops are the worst bugs + +When `load_weights()` with `strict=False` silently skips keys that don't match, you +get a model with random weights for those layers. This produces output that looks +"almost right" but is subtly wrong. Always log which keys were loaded vs skipped: + +```python +loaded_keys = set() +for key, value in weights: + if key in model_params: + loaded_keys.add(key) +# Check for missing +expected = set(model_params.keys()) +missing = expected - loaded_keys +if missing: + print(f"WARNING: {len(missing)} weights not loaded: {list(missing)[:5]}...") +``` + +--- + +## 8. Text Conditioning Failures + +### Symptom: model predicts noise back to itself + +If the model output correlates > 0.8 with its input noise, text conditioning is +likely broken. The model has learned nothing from the prompt and is just returning +its input. 
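The correlation check can be scripted in a few lines. The sketch below uses only the standard library and synthetic data: `pearson_correlation` is a hypothetical helper, and `noise`, `echoed_output`, and `healthy_output` are stand-ins for flattened tensors that a real pipeline would supply. The 0.8 threshold is the one quoted above.

```python
import math
import random


def pearson_correlation(a, b):
    """Pearson correlation between two equal-length sequences of floats."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0.0 or var_b == 0.0:
        return 0.0  # a constant signal carries no correlation information
    return cov / math.sqrt(var_a * var_b)


# Synthetic stand-ins for a flattened noise tensor and two model outputs.
rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(2000)]
echoed_output = [0.9 * n + 0.1 * rng.gauss(0.0, 1.0) for n in noise]
healthy_output = [rng.gauss(0.0, 1.0) for _ in noise]

for label, output in [("echoed", echoed_output), ("healthy", healthy_output)]:
    corr = pearson_correlation(output, noise)
    verdict = "conditioning likely broken" if corr > 0.8 else "ok"
    print(f"{label}: corr={corr:.3f} -> {verdict}")
```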
+ +### Check embedding statistics + +```python +text_emb = text_encoder(prompt) +print(f"text_emb: mean={text_emb.mean():.4f} std={text_emb.std():.4f}") +# std < 0.01 → embeddings are collapsed → broken encoder or wrong weights +# std > 10.0 → embeddings are exploding → wrong normalization +``` + +### Verify with ablation + +```python +# Generate with real text +output_text = denoise(latent, text_emb=real_embeddings) +# Generate with zeros +output_zero = denoise(latent, text_emb=mx.zeros_like(real_embeddings)) +# Compare +text_influence = np.mean(np.abs(output_text - output_zero)) +print(f"Text influence: {text_influence:.4f}") # Should be > 0 (typically 30-60% of output) +``` + +### Text preprocessing must match exactly + +**Bug** (Wan2.2): Patchy-blurry output from wrong negative prompt tokenization. + +The official Wan2.2 tokenizer applies `ftfy.fix_text` + `html.unescape` + whitespace +normalization before tokenization. Without this, fullwidth Chinese commas (U+FF0C) +tokenize differently from ASCII commas (U+002C), causing **27 different token IDs** +in the negative prompt. This made CFG's unconditional prediction wrong. + +**Fix**: Apply the same text cleaning pipeline as the reference: +```python +import ftfy +import html +import re + +def clean_text(text): + text = ftfy.fix_text(text) + text = html.unescape(text) + text = re.sub(r'\s+', ' ', text).strip() + return text +``` + +### T5 encoder precision + +**Bug** (Wan2.2): Quality degradation from bfloat16 T5 attention. + +T5 uses **no scaling** in attention (no `1/sqrt(d)` factor), so attention logits can +be very large. bfloat16 softmax loses significant precision across 24 encoder layers. + +**Fix**: Compute T5 QK^T and softmax in float32. The T5 encoder only runs once per +generation, so the performance impact is negligible. + +### Dual-model text embeddings + +**Bug** (Wan2.2 I2V-14B): Low/high noise models have different `text_embedding` weights +(~42% relative difference). 
Using one model's embeddings for both caused incorrect +text conditioning for the high-noise model that handles critical early denoising steps. + +**Fix**: Compute separate text embeddings for each model in dual-model setups. + +--- + +## 9. Position Encodings (RoPE) + +### Multi-scale consistency + +In pyramid/multi-resolution models, RoPE must be computed consistently across scales. +If the model operates at 1/4 resolution in an early stage, the position grid must +reflect the actual spatial dimensions, not the final target dimensions. + +### History vs current chunk + +When conditioning on history from a previous chunk, the position encoding for +history frames must match what the model saw during training. Mismatches between +history and current-chunk position encodings can cause subtle spatial distortions +that compound across chunks. + +### Factorized RoPE + +3D video models often use factorized RoPE (separate temporal, height, width +frequencies). Verify each axis independently: + +```python +# Compare temporal frequencies +assert np.allclose(mlx_rope_t, ref_rope_t, atol=1e-5) +# Compare spatial frequencies +assert np.allclose(mlx_rope_h, ref_rope_h, atol=1e-5) +assert np.allclose(mlx_rope_w, ref_rope_w, atol=1e-5) +``` + +### Per-axis frequency construction + +**Bug** (Wan2.2): Grey/artifact-filled output from wrong frequency distribution. + +The reference uses three separate `rope_params()` calls with different dimension +normalizations (e.g., 44, 42, 42 for Wan) so each axis gets its own full frequency +range. Consolidating into a single `rope_params(head_dim)` call and splitting gave +height frequencies starting at 0.042 and width at 0.002 (should be 1.0 for both). + +**Fix** (and subsequent revert): This bug was introduced as a "fix" for a previous +RoPE issue, then had to be reverted. The lesson: RoPE changes have far-reaching effects. +Always verify with actual generation, not just numerical comparison of frequencies. 
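A numerical check does not replace a generation test, but it makes the difference between the two constructions concrete. The sketch below is illustrative only: `rope_frequencies` is a hypothetical stand-in for the reference's `rope_params()`, and `head_dim = 128` and `theta = 10000` are assumptions chosen so that the three axis dimensions (44, 42, 42) sum to the head dimension.

```python
def rope_frequencies(dim, theta=10000.0):
    """Inverse frequencies 1 / theta**(i / dim) for i = 0, 2, 4, ..., dim - 2."""
    return [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]


head_dim = 128            # assumed for illustration
axis_dims = (44, 42, 42)  # temporal, height, width

# Per-axis construction: three separate calls, each with its own normalization.
per_axis = [rope_frequencies(d) for d in axis_dims]

# Consolidated construction: one call over head_dim, then split by position.
combined = rope_frequencies(head_dim)
splits, start = [], 0
for d in axis_dims:
    splits.append(combined[start:start + d // 2])
    start += d // 2

for name, separate, split in zip(("t", "h", "w"), per_axis, splits):
    print(f"{name}: separate starts at {separate[0]:.3f}, "
          f"split starts at {split[0]:.3f}")
```

Under these assumptions the separate calls all start at 1.0, while the split version starts the height axis near 0.042 and the width axis near 0.002, matching the values reported above.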
+ +**Lesson**: Read the reference's frequency construction very carefully. Don't +"simplify" three separate calls into one unless you verify the frequency distribution +matches exactly. + +--- + +## 10. Multi-Stage / Pyramid Pipelines + +### Each stage is a potential failure point + +Pyramid pipelines (generate at low res, upsample, refine at high res) multiply the +number of things that can go wrong: + +- Downsampling method (bilinear vs area) must match reference +- Energy compensation factors (e.g., ×2 after bilinear downsample) must be present +- Alpha/beta noise mixing coefficients are stage-dependent +- Frame indices and history resolution change per stage + +### Test single-stage first + +If the model works at full resolution for a single stage but fails in the pyramid, +the bug is in stage orchestration — typically in how latents are passed between +stages or how position encodings adapt to different resolutions. + +### Integration bugs are the hardest + +We verified every Helios component matched the reference individually, but the +pyramid still produced uniform color. The bug was in dtype handling during stage +transitions. Integration bugs only appear when components interact. + +--- + +## 11. 
Common Symptoms → Root Causes + +| Symptom | Likely Root Causes | +|---------|-------------------| +| **Pure noise output** | Wrong sigma schedule, broken text conditioning, incorrect weight mapping | +| **Uniform color** | Model predicting noise back; text embeddings collapsed; wrong timestep format | +| **Progressive zoom/shrink** | bfloat16 residuals truncating high-freq detail; RoPE mismatch across chunks | +| **Brightness jumps at boundaries** | VAE causal warmup; cross-fade blending misaligned content | +| **Color drift across chunks** | Dtype in scheduler step; history normalization missing | +| **Blur at boundaries** | Cross-fade enabled; latent blending; wrong VAE decode order | +| **Grid/checker patterns** | Patchify channel ordering bug; latent blend artifacts; reshape axis error | +| **Green/magenta tint** | VAE weight key mismatch; wrong denormalization constants; cv2 YUV color matrix | +| **Mean drift across steps** | bfloat16 accumulation; wrong scheduler formula; missing energy compensation | +| **Garbled/scrambled output** | Silent weight drops (underscore prefix, wrong key mapping); RMS vs L2 norm | +| **Greenish-yellow constant** | Scheduler boundary condition (log(0) not returning -inf); x0 overscaling | +| **~2x slower than expected** | Float32 promotion cascade from single mistyped intermediate | +| **Extra output frames** | Causal padding producing extra temporal frames; missing `first_chunk` trim | +| **Grey/artifact output** | RoPE frequency construction wrong (per-axis vs single-call) | +| **Patchy-blurry with CFG** | Text preprocessing mismatch (fullwidth vs ASCII chars → wrong tokenization) | +| **I2V temporal mismatch** | Non-chunked VAE encoding vs reference's chunked encoding with temporal cache | + +--- + +## 12. 
Verification Checklist
+
+Use this checklist when porting a new diffusion video model:
+
+### Model
+- [ ] Weight conversion: all keys mapped, cosine similarity > 0.9999
+- [ ] No silent weight drops (log loaded vs expected keys)
+- [ ] Single forward pass matches reference (cos_sim > 0.999)
+- [ ] Residual connections use float32 accumulation
+- [ ] Attention computation matches reference precision
+- [ ] Modulation/gate vectors in float32 (if reference uses autocast)
+- [ ] No underscore-prefixed module attributes (MLX ignores them)
+
+### Scheduler
+- [ ] Sigma values match reference at every step (diff < 1e-6)
+- [ ] Timestep format correct (int vs float, scale factor)
+- [ ] Dynamic shifting formula copied exactly
+- [ ] Step function returns correct dtype (float32)
+- [ ] Boundary conditions: lambda = -inf at sigma=1, lambda = +inf at sigma=0
+- [ ] Higher-order coefficients computed (not hardcoded) for UniPC/DPM++
+
+### Text Encoder
+- [ ] Embedding statistics reasonable (0.01 < std < 10)
+- [ ] Text influence > 0 (ablation test)
+- [ ] Tokenization matches (special tokens, padding, max length)
+- [ ] Text preprocessing matches (ftfy, html unescape, whitespace normalization)
+- [ ] T5/CLIP attention precision (float32 softmax if no 1/sqrt(d) scaling)
+- [ ] Separate embeddings for dual-model setups (if applicable)
+
+### VAE
+- [ ] Denormalization constants match training pipeline
+- [ ] Per-chunk decoding (not concatenated)
+- [ ] Temporal frame count correct (account for causal padding)
+- [ ] Weight keys mapped correctly (encoder vs decoder)
+- [ ] Weights stored/loaded in float32 (not bfloat16)
+- [ ] Temporal downsample/upsample order matches reference
+- [ ] RMS_norm vs L2_norm: check actual implementation, not class name
+- [ ] Chunked encoding for I2V (if applicable)
+- [ ] Reshape axis ordering correct ([B,C,T,H,W] → transpose before reshape)
+
+### Pipeline Orchestration
+- [ ] Position encodings consistent across stages/chunks
+- [ ] History
slicing and conditioning correct
+- [ ] Noise generation matches (distribution, correlation structure)
+- [ ] Multi-chunk output visually consistent (no progressive degradation)
+- [ ] Dimension auto-alignment (divisible by patch_size × vae_stride)
+- [ ] Dtype-aware padding (mx.zeros with explicit dtype)
+
+### Output
+- [ ] Frame count matches expected (account for warmup/trim)
+- [ ] FPS correct
+- [ ] Color range [0, 255] uint8 for video
+- [ ] No first-frame duplication artifacts
+- [ ] Video codec correct (imageio/libx264 preferred over cv2/mp4v on macOS)
+
+### Performance
+- [ ] No float32 promotion cascades (check with profiler)
+- [ ] Using mx.fast kernels (rms_norm, layer_norm, sdpa)
+- [ ] Time embedding computed once per sample (not per position)
+- [ ] Memory cleanup (delete temporaries before mx.eval)
+
+---
+
+## 13. Diagnostic Tools
+
+### General video diagnostics (`scripts/video/`)
+
+| Script | Purpose |
+|--------|---------|
+| `compare_videos.py` | PSNR, SSIM, temporal coherence, color fidelity between two videos |
+| `video_quality.py` | Sharpness, stability, defect detection, chunk boundary analysis |
+
+```bash
+# Quick quality check
+python scripts/video/video_quality.py output.mp4 --chunk-size 32
+
+# Compare against reference
+python scripts/video/compare_videos.py reference.mp4 output.mp4 --diff-video diff.mp4
+```
+
+### Model-specific diagnostics (`scripts/helios/`)
+
+| Script | Purpose |
+|--------|---------|
+| `analyze_boundaries.py` | Detailed boundary quality metrics for Helios |
+| `run_reference.py` | Run PyTorch reference on MPS |
+| `compare_pipelines.py` | Compare scheduler/pipeline mechanics |
+| `compare_models.py` | Cross-framework model output comparison |
+
+### Inline debugging pattern
+
+Add temporary debug output to the diffusion loop:
+
+```python
+import os
+
+# Assumes sigmas ends with the terminal value, so each step sees (sigma, sigma_next)
+for i, (sigma, sigma_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):
+    flow = model(latent, sigma, text_emb)
+    latent = scheduler.step(latent, flow, sigma, sigma_next)
+
+    # Debug: track statistics
print(f"[step {i}] sigma={sigma:.4f} " + f"latent: mean={latent.mean():.6f} std={latent.std():.6f} " + f"flow: mean={flow.mean():.6f} std={flow.std():.6f}") + + # Debug: save for cross-framework comparison (needs "import os") + if os.environ.get("DEBUG"): + mx.savez( + f"/tmp/debug_step_{i}.npz", latent=latent, flow=flow, sigma=mx.array(sigma) + ) +``` + +--- + +## Key Takeaways + +1. **Precision is the #1 bug source**: bfloat16 residuals, scheduler math, type + promotion, modulation vectors. Copy the reference's `.float()` and `autocast` zones. + +2. **Don't add what the reference doesn't have**: cross-fade, overlap decode, + temporal blending. If the reference works without it, you probably have a bug + elsewhere. + +3. **Silent failures are the hardest bugs**: underscore-prefixed weights, `strict=False` + weight loading, wrong normalization class names. Always verify weight load counts + and output statistics. + +4. **Component isolation → integration testing**: verify each part matches, then + debug their interaction. + +5. **Statistical comparison beats visual inspection**: mean drift, contrast ratios, + and cosine similarity catch bugs before they're visible. + +6. **Autoregressive errors compound**: a 1% error per chunk grows to roughly 10% by chunk 10. + Fix precision first, add corrections second. + +7. **MLX has unique pitfalls**: underscore attribute names, reshape axis ordering, + dtype-unaware padding, and float32 promotion cascades. Know your framework. + +8. **Text preprocessing matters**: Unicode normalization, fullwidth chars, HTML entities. + A single mismatched comma can break classifier-free guidance (CFG). + +9. **VAE is deceptively complex**: causal padding, temporal frame counts, chunked vs + batch processing, norm implementations. Budget significant debugging time for VAE.
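The advice in takeaway 3 to verify weight load counts can be turned into a small check that runs right after weight conversion. The following is a minimal sketch, assuming parameter names are available as flat dotted strings on both sides. The helper `audit_weight_keys` and the example key names are hypothetical and not taken from the Wan or Helios ports.

```python
def audit_weight_keys(expected_keys, loaded_keys):
    """Compare the parameter names a model expects with those in a checkpoint."""
    expected, loaded = set(expected_keys), set(loaded_keys)
    return {
        # Parameters the model defines but the checkpoint never provided.
        "missing": sorted(expected - loaded),
        # Checkpoint entries that no model parameter would receive.
        "unexpected": sorted(loaded - expected),
        # Names with an underscore-prefixed path component, which the guide
        # lists as a source of silently ignored module attributes.
        "underscore": sorted(
            k for k in loaded if any(part.startswith("_") for part in k.split("."))
        ),
    }


# Hypothetical key names, for illustration only.
report = audit_weight_keys(
    expected_keys=["blocks.0.attn.q.weight", "blocks.0.attn.k.weight"],
    loaded_keys=["blocks.0.attn.q.weight", "blocks.0._gate.weight"],
)
for category, keys in report.items():
    print(f"{category}: {len(keys)} {keys}")
```

Any non-empty category is worth investigating before comparing forward passes, since a dropped weight usually shows up later as a vague quality problem rather than an error.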
diff --git a/scripts/video/compare_videos.py b/scripts/video/compare_videos.py new file mode 100644 index 0000000..1d18804 --- /dev/null +++ b/scripts/video/compare_videos.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +"""Compare two videos frame-by-frame with quality metrics. + +Useful for validating MLX ports against reference PyTorch implementations. +Reports PSNR, SSIM, per-frame differences, temporal coherence, and color +fidelity. Optionally saves a side-by-side diff video. + +Usage: + # Basic comparison + python scripts/video/compare_videos.py reference.mp4 test.mp4 + + # Save side-by-side diff visualization + python scripts/video/compare_videos.py ref.mp4 test.mp4 --diff-video diff.mp4 + + # Compare only first 64 frames + python scripts/video/compare_videos.py ref.mp4 test.mp4 --max-frames 64 + + # Adjust SSIM window size (default: 7) + python scripts/video/compare_videos.py ref.mp4 test.mp4 --ssim-win 11 +""" + +import argparse +import sys + +import cv2 +import numpy as np + + +def load_video(path, max_frames=None): + """Load video frames as float32 numpy arrays (0-255 range).""" + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + print(f"Error: cannot open {path}") + sys.exit(1) + + fps = cap.get(cv2.CAP_PROP_FPS) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(frame.astype(np.float32)) + if max_frames and len(frames) >= max_frames: + break + cap.release() + return frames, fps + + +def compute_psnr(a, b): + """Peak Signal-to-Noise Ratio between two frames.""" + mse = np.mean((a - b) ** 2) + if mse == 0: + return float("inf") + return 10 * np.log10(255.0**2 / mse) + + +def compute_ssim(a, b, win_size=7): + """Structural Similarity Index (per-channel, averaged). + + Uses the standard SSIM formula with default constants. 
+ """ + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + kernel = cv2.getGaussianKernel(win_size, 1.5) + window = kernel @ kernel.T + + ssim_channels = [] + for c in range(a.shape[2]): + ac, bc = a[:, :, c], b[:, :, c] + mu_a = cv2.filter2D(ac, -1, window) + mu_b = cv2.filter2D(bc, -1, window) + + mu_a_sq = mu_a**2 + mu_b_sq = mu_b**2 + mu_ab = mu_a * mu_b + + sigma_a_sq = cv2.filter2D(ac**2, -1, window) - mu_a_sq + sigma_b_sq = cv2.filter2D(bc**2, -1, window) - mu_b_sq + sigma_ab = cv2.filter2D(ac * bc, -1, window) - mu_ab + + num = (2 * mu_ab + C1) * (2 * sigma_ab + C2) + den = (mu_a_sq + mu_b_sq + C1) * (sigma_a_sq + sigma_b_sq + C2) + ssim_map = num / den + ssim_channels.append(np.mean(ssim_map)) + + return np.mean(ssim_channels) + + +def temporal_coherence(frames): + """Mean frame-to-frame difference (lower = smoother).""" + if len(frames) < 2: + return 0.0 + diffs = [] + for i in range(1, len(frames)): + diffs.append(np.mean(np.abs(frames[i] - frames[i - 1]))) + return np.mean(diffs) + + +def color_histogram_distance(a, b, bins=64): + """Chi-squared distance between color histograms.""" + dist = 0.0 + for c in range(3): + ha, _ = np.histogram(a[:, :, c], bins=bins, range=(0, 256)) + hb, _ = np.histogram(b[:, :, c], bins=bins, range=(0, 256)) + ha = ha.astype(np.float64) / (ha.sum() + 1e-10) + hb = hb.astype(np.float64) / (hb.sum() + 1e-10) + dist += np.sum((ha - hb) ** 2 / (ha + hb + 1e-10)) / 2 + return dist / 3 + + +def make_diff_frame(a, b, scale=5.0): + """Create a heatmap visualization of the absolute difference.""" + diff = np.mean(np.abs(a - b), axis=2) + diff_scaled = np.clip(diff * scale, 0, 255).astype(np.uint8) + heatmap = cv2.applyColorMap(diff_scaled, cv2.COLORMAP_JET) + return heatmap + + +def analyze(ref_frames, test_frames, ssim_win=7): + """Compute per-frame and aggregate metrics.""" + n = min(len(ref_frames), len(test_frames)) + + psnrs = [] + ssims = [] + mean_diffs = [] + max_diffs = [] + color_dists = [] + + for i in range(n): + r, t = 
ref_frames[i], test_frames[i] + psnrs.append(compute_psnr(r, t)) + ssims.append(compute_ssim(r, t, ssim_win)) + absdiff = np.abs(r - t) + mean_diffs.append(np.mean(absdiff)) + max_diffs.append(np.max(absdiff)) + color_dists.append(color_histogram_distance(r, t)) + + ref_tc = temporal_coherence(ref_frames[:n]) + test_tc = temporal_coherence(test_frames[:n]) + + return { + "num_frames": n, + "psnr": np.array(psnrs), + "ssim": np.array(ssims), + "mean_diff": np.array(mean_diffs), + "max_diff": np.array(max_diffs), + "color_dist": np.array(color_dists), + "ref_temporal_coherence": ref_tc, + "test_temporal_coherence": test_tc, + } + + +def print_report(results, ref_path, test_path): + """Print a formatted comparison report.""" + n = results["num_frames"] + psnr = results["psnr"] + ssim = results["ssim"] + md = results["mean_diff"] + mx = results["max_diff"] + cd = results["color_dist"] + + print("=" * 72) + print("VIDEO COMPARISON REPORT") + print("=" * 72) + print(f" Reference: {ref_path}") + print(f" Test: {test_path}") + print(f" Frames compared: {n}") + print() + + print("AGGREGATE METRICS") + print("-" * 40) + print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}") + print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}") + print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}") + print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}") + print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}") + print() + + print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)") + print("-" * 40) + print(f" Reference: {results['ref_temporal_coherence']:.2f}") + print(f" Test: {results['test_temporal_coherence']:.2f}") + ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10) + print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' 
if ratio > 1.05 else '(similar)'}") + print() + + # Identify worst frames + print("WORST FRAMES (by PSNR)") + print("-" * 40) + worst_idx = np.argsort(psnr)[:5] + for i in worst_idx: + print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}") + print() + + # Quality assessment + mean_psnr = np.mean(psnr) + mean_ssim = np.mean(ssim) + print("QUALITY ASSESSMENT") + print("-" * 40) + if mean_psnr > 40: + grade = "Excellent" + elif mean_psnr > 35: + grade = "Good" + elif mean_psnr > 30: + grade = "Fair" + elif mean_psnr > 25: + grade = "Poor" + else: + grade = "Very different" + print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})") + if mean_psnr < 30: + print(" ⚠ Videos differ significantly — likely a bug or different generation seed") + print("=" * 72) + + +def save_diff_video(ref_frames, test_frames, output_path, fps, scale=5.0): + """Save a side-by-side video: reference | test | diff heatmap.""" + n = min(len(ref_frames), len(test_frames)) + h, w = ref_frames[0].shape[:2] + + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(output_path, fourcc, fps, (w * 3, h)) + + for i in range(n): + r = ref_frames[i].astype(np.uint8) + t = test_frames[i].astype(np.uint8) + d = make_diff_frame(ref_frames[i], test_frames[i], scale) + combined = np.hstack([r, t, d]) + out.write(combined) + + out.release() + print(f"Diff video saved to {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Compare two videos frame-by-frame with quality metrics" + ) + parser.add_argument("reference", help="Path to reference video") + parser.add_argument("test", help="Path to test video") + parser.add_argument( + "--diff-video", help="Save side-by-side diff visualization to this path" + ) + parser.add_argument( + "--max-frames", type=int, help="Compare only first N frames" + ) + parser.add_argument( + "--ssim-win", type=int, default=7, help="SSIM window size (default: 7)" + ) + parser.add_argument( + 
"--diff-scale", + type=float, + default=5.0, + help="Diff heatmap amplification (default: 5.0)", + ) + parser.add_argument( + "--csv", help="Export per-frame metrics to CSV file" + ) + args = parser.parse_args() + + print(f"Loading reference: {args.reference}") + ref_frames, ref_fps = load_video(args.reference, args.max_frames) + print(f" → {len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}") + + print(f"Loading test: {args.test}") + test_frames, test_fps = load_video(args.test, args.max_frames) + print(f" → {len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}") + + if ref_frames[0].shape != test_frames[0].shape: + print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}") + print("Resizing test frames to match reference...") + h, w = ref_frames[0].shape[:2] + test_frames = [ + cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) + for f in test_frames + ] + + print("Computing metrics...") + results = analyze(ref_frames, test_frames, args.ssim_win) + print() + print_report(results, args.reference, args.test) + + if args.diff_video: + save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale) + + if args.csv: + import csv + + with open(args.csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]) + for i in range(results["num_frames"]): + writer.writerow([ + i, + f"{results['psnr'][i]:.4f}", + f"{results['ssim'][i]:.6f}", + f"{results['mean_diff'][i]:.4f}", + f"{results['max_diff'][i]:.1f}", + f"{results['color_dist'][i]:.6f}", + ]) + print(f"Per-frame metrics saved to {args.csv}") + + +if __name__ == "__main__": + main() diff --git a/scripts/video/video_quality.py b/scripts/video/video_quality.py new file mode 100644 index 0000000..f756b5a --- /dev/null +++ b/scripts/video/video_quality.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +"""Analyze 
quality of a single generated video. + +Reports sharpness, temporal stability, color distribution, motion smoothness, +chunk boundary artifacts, and common generation defects. Useful for quick +quality checks during model porting and debugging. + +Usage: + # Basic analysis + python scripts/video/video_quality.py output.mp4 + + # With chunk boundary analysis (e.g., 32 frames/chunk) + python scripts/video/video_quality.py output.mp4 --chunk-size 32 + + # Detailed per-frame CSV export + python scripts/video/video_quality.py output.mp4 --csv metrics.csv + + # Analyze specific frame range + python scripts/video/video_quality.py output.mp4 --start 0 --end 64 +""" + +import argparse +import sys + +import cv2 +import numpy as np + + +def load_video(path, start=0, end=None): + """Load video frames as float32 numpy arrays (0-255 range).""" + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + print(f"Error: cannot open {path}") + sys.exit(1) + + fps = cap.get(cv2.CAP_PROP_FPS) + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if start > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, start) + + frames = [] + idx = start + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(frame.astype(np.float32)) + idx += 1 + if end and idx >= end: + break + cap.release() + return frames, fps, total + + +def sharpness_laplacian(frame): + """Laplacian variance — higher means sharper.""" + gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY) + return cv2.Laplacian(gray, cv2.CV_64F).var() + + +def sharpness_gradient(frame): + """Mean gradient magnitude — higher means more edges/detail.""" + gray = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float32) + gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3) + gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3) + return np.mean(np.sqrt(gx**2 + gy**2)) + + +def color_stats(frame): + """Per-channel mean and std in BGR order.""" + means = [np.mean(frame[:, :, c]) for c in range(3)] + stds = 
[np.std(frame[:, :, c]) for c in range(3)] + return means, stds + + +def detect_uniform_color(frame, std_threshold=15.0): + """Detect if frame is near-uniform (common failure mode).""" + return np.std(frame) < std_threshold + + +def detect_noise(frame, threshold=200.0): + """High Laplacian variance with low gradient can indicate noise.""" + lap = sharpness_laplacian(frame) + grad = sharpness_gradient(frame) + # Noise has high variance but less coherent edges + return lap > threshold and grad < 5.0 + + +def frame_difference(a, b): + """Mean absolute pixel difference between frames.""" + return np.mean(np.abs(a - b)) + + +def optical_flow_magnitude(prev, curr): + """Mean optical flow magnitude (Farneback method).""" + prev_gray = cv2.cvtColor(prev.astype(np.uint8), cv2.COLOR_BGR2GRAY) + curr_gray = cv2.cvtColor(curr.astype(np.uint8), cv2.COLOR_BGR2GRAY) + flow = cv2.calcOpticalFlowFarneback( + prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0 + ) + mag = np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2) + return np.mean(mag), np.max(mag) + + +def analyze_video(frames, chunk_size=None, compute_flow=False): + """Compute per-frame and aggregate quality metrics.""" + n = len(frames) + + metrics = { + "sharpness_lap": [], + "sharpness_grad": [], + "brightness": [], + "contrast": [], + "color_mean_b": [], + "color_mean_g": [], + "color_mean_r": [], + "frame_diff": [], + "is_uniform": [], + "is_noisy": [], + } + if compute_flow: + metrics["flow_mean"] = [] + metrics["flow_max"] = [] + + for i in range(n): + f = frames[i] + metrics["sharpness_lap"].append(sharpness_laplacian(f)) + metrics["sharpness_grad"].append(sharpness_gradient(f)) + metrics["brightness"].append(np.mean(f)) + metrics["contrast"].append(np.std(f)) + means, _ = color_stats(f) + metrics["color_mean_b"].append(means[0]) + metrics["color_mean_g"].append(means[1]) + metrics["color_mean_r"].append(means[2]) + metrics["is_uniform"].append(detect_uniform_color(f)) + metrics["is_noisy"].append(detect_noise(f)) + + 
if i > 0: + metrics["frame_diff"].append(frame_difference(frames[i - 1], f)) + if compute_flow: + fm, fmx = optical_flow_magnitude(frames[i - 1], f) + metrics["flow_mean"].append(fm) + metrics["flow_max"].append(fmx) + else: + metrics["frame_diff"].append(0.0) + if compute_flow: + metrics["flow_mean"].append(0.0) + metrics["flow_max"].append(0.0) + + # Convert to arrays + for k in metrics: + metrics[k] = np.array(metrics[k]) + + # Chunk boundary analysis + if chunk_size and n > chunk_size: + boundaries = list(range(chunk_size, n, chunk_size)) + boundary_metrics = [] + for b in boundaries: + if b < n and b > 0: + pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1] + at = metrics["frame_diff"][b] + ratio = at / (pre + 1e-10) + brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1] + contrast_jump = ( + (metrics["contrast"][b] - metrics["contrast"][b - 1]) + / (metrics["contrast"][b - 1] + 1e-10) + * 100 + ) + sharpness_jump = ( + (metrics["sharpness_lap"][b] - metrics["sharpness_lap"][b - 1]) + / (metrics["sharpness_lap"][b - 1] + 1e-10) + * 100 + ) + boundary_metrics.append( + { + "frame": b, + "diff_ratio": ratio, + "brightness_jump": brightness_jump, + "contrast_jump_pct": contrast_jump, + "sharpness_jump_pct": sharpness_jump, + } + ) + metrics["boundaries"] = boundary_metrics + + return metrics + + +def print_report(metrics, path, fps, total_frames, frames_analyzed): + """Print a formatted quality report.""" + sl = metrics["sharpness_lap"] + sg = metrics["sharpness_grad"] + br = metrics["brightness"] + ct = metrics["contrast"] + fd = metrics["frame_diff"] + + print("=" * 72) + print("VIDEO QUALITY REPORT") + print("=" * 72) + print(f" File: {path}") + print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}") + duration = total_frames / fps if fps > 0 else 0 + print(f" Duration: {duration:.1f}s") + print() + + # Defect detection + n_uniform = int(np.sum(metrics["is_uniform"])) + n_noisy = 
int(np.sum(metrics["is_noisy"])) + if n_uniform > 0 or n_noisy > 0: + print("⚠ DEFECTS DETECTED") + print("-" * 40) + if n_uniform: + frames_list = np.where(metrics["is_uniform"])[0][:10] + print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}") + if n_noisy: + frames_list = np.where(metrics["is_noisy"])[0][:10] + print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}") + print() + + print("SHARPNESS") + print("-" * 40) + print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}") + print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}") + if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3: + print(" ⚠ High sharpness variation — possible blur artifacts") + print() + + print("BRIGHTNESS & CONTRAST") + print("-" * 40) + print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}") + print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}") + if np.std(br) > 3.0: + print(" ⚠ Brightness instability — may indicate chunk boundary artifacts") + print() + + print("COLOR DISTRIBUTION (BGR)") + print("-" * 40) + print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}") + print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}") + print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}") + print() + + print("TEMPORAL STABILITY") + print("-" * 40) + fd_nz = fd[1:] # skip first frame (always 0) + if len(fd_nz) > 0: + print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}") + if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5: + print(" ⚠ High diff variance — jitter or 
discontinuities") + if "flow_mean" in metrics: + fm = metrics["flow_mean"][1:] + print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}") + print() + + # Chunk boundaries + if "boundaries" in metrics and metrics["boundaries"]: + print("CHUNK BOUNDARIES") + print("-" * 40) + print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}") + for bm in metrics["boundaries"]: + print( + f" {bm['frame']:6d}" + f" {bm['diff_ratio']:10.2f}x" + f" {bm['brightness_jump']:+10.1f}" + f" {bm['contrast_jump_pct']:+10.1f}%" + f" {bm['sharpness_jump_pct']:+11.1f}%" + ) + avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]]) + if avg_ratio > 2.0: + print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions") + print() + + # Overall grade + print("OVERALL ASSESSMENT") + print("-" * 40) + issues = [] + if n_uniform > 0: + issues.append("uniform/blank frames") + if n_noisy > 0: + issues.append("noisy frames") + if np.std(br) > 3.0: + issues.append("brightness flicker") + if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3: + issues.append("sharpness variation") + if "boundaries" in metrics and metrics["boundaries"]: + avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]]) + if avg_ratio > 2.0: + issues.append("chunk boundary artifacts") + if issues: + print(f" Issues found: {', '.join(issues)}") + else: + print(" ✓ No significant quality issues detected") + print("=" * 72) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze quality of a single generated video" + ) + parser.add_argument("video", help="Path to video file") + parser.add_argument( + "--chunk-size", + type=int, + help="Frames per chunk for boundary analysis (e.g., 32)", + ) + parser.add_argument( + "--start", type=int, default=0, help="Start frame (default: 0)" + ) + parser.add_argument("--end", type=int, help="End frame (default: all)") + parser.add_argument( + "--flow", + 
action="store_true", + help="Compute optical flow (slower but more detailed)", + ) + parser.add_argument("--csv", help="Export per-frame metrics to CSV") + args = parser.parse_args() + + print(f"Loading: {args.video}") + frames, fps, total = load_video(args.video, args.start, args.end) + h, w = frames[0].shape[:2] + print(f" → {len(frames)} frames, {fps:.1f} fps, {w}x{h}") + + print("Analyzing...") + metrics = analyze_video(frames, args.chunk_size, args.flow) + print() + print_report(metrics, args.video, fps, total, len(frames)) + + if args.csv: + import csv + + keys = [ + "sharpness_lap", "sharpness_grad", "brightness", "contrast", + "color_mean_b", "color_mean_g", "color_mean_r", "frame_diff", + ] + if args.flow: + keys += ["flow_mean", "flow_max"] + + with open(args.csv, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["frame"] + keys) + for i in range(len(frames)): + row = [i] + [f"{metrics[k][i]:.4f}" for k in keys] + writer.writerow(row) + print(f"Per-frame metrics saved to {args.csv}") + + +if __name__ == "__main__": + main()
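The chunk boundary analysis in `video_quality.py` compares the frame difference at each boundary with the difference on the frame just before it. The following is a minimal standalone sketch of that ratio in pure Python, assuming a precomputed list of per-frame differences. `boundary_diff_ratios` is a hypothetical helper, not a function in the script, and it flags each boundary individually, whereas the script warns on the average ratio.

```python
def boundary_diff_ratios(frame_diff, chunk_size, eps=1e-10):
    """Return (frame, ratio) pairs for each chunk boundary.

    frame_diff[i] is the mean absolute difference between frame i and
    frame i - 1, with frame_diff[0] set to 0.0 by convention.
    """
    ratios = []
    for b in range(chunk_size, len(frame_diff), chunk_size):
        # Difference just before the boundary; fall back to index 1 because
        # index 0 is a placeholder, mirroring the script's handling.
        pre = frame_diff[b - 1] if b > 1 else frame_diff[1]
        ratios.append((b, frame_diff[b] / (pre + eps)))
    return ratios


# Synthetic data: steady motion (diff 2.0) with jumps at frames 4 and 8.
diffs = [0.0, 2.0, 2.0, 2.0, 6.0, 2.0, 2.0, 2.0, 5.0, 2.0]
ratios = boundary_diff_ratios(diffs, chunk_size=4)
for frame, ratio in ratios:
    flag = "visible transition" if ratio > 2.0 else "ok"
    print(f"frame {frame}: {ratio:.2f}x ({flag})")
```

A ratio near 1.0 means the boundary frame changes about as much as its neighbours. The 2.0 threshold is the one the script uses for its warning.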