feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
Maps PyTorch nn.Sequential indices to our named layers.
"""
sanitized = {}
consumed = set()
for key, value in weights.items():
# Skip encoder and conv1 unless requested
if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."):
consumed.add(key)
continue
new_key = key
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
value = mx.array(np.array(value).squeeze())
sanitized[new_key] = value
consumed.add(key)
unconsumed = set(weights.keys()) - consumed
if unconsumed:
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
return sanitized