feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -7,12 +7,15 @@ Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
|
||||
conversion (channels-first → channels-last) is needed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization for z_dim=48 latent space
|
||||
@@ -774,11 +777,13 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
||||
Maps PyTorch nn.Sequential indices to our named layers.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
# Skip encoder and conv1 unless requested
|
||||
if not include_encoder:
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
new_key = key
|
||||
@@ -832,5 +837,10 @@ def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) ->
|
||||
value = mx.array(np.array(value).squeeze())
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed Wan2.2 VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
Reference in New Issue
Block a user