feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -6,14 +6,15 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
"""Load and initialize WanModel, with optional quantization support.
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
model_path: Path to model safetensors file
|
||||
config: WanModelConfig
|
||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||
"""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
@@ -30,6 +31,27 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Apply LoRAs: quantized models dequantize the targeted layers and merge after loading; bf16 models merge into the weights before loading
|
||||
if loras:
|
||||
if quantization:
|
||||
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
|
||||
# Non-LoRA layers stay quantized (at the configured bit width). Zero per-step overhead.
|
||||
from mlx_video.convert_wan import _load_lora_configs
|
||||
from mlx_video.lora import apply_loras_to_model
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
module_to_loras = _load_lora_configs(loras)
|
||||
apply_loras_to_model(model, module_to_loras)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
else:
|
||||
# Weight merging: fold LoRA into bf16 weights before loading
|
||||
from mlx_video.convert_wan import load_and_apply_loras
|
||||
|
||||
weights = load_and_apply_loras(dict(weights), loras)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user