feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Tests for Wan weight conversion utilities."""
|
||||
|
||||
import logging
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -94,6 +96,27 @@ class TestSanitizeTransformerWeights:
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
@@ -119,6 +142,19 @@ class TestSanitizeT5Weights:
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
@@ -161,6 +197,18 @@ class TestSanitizeVAEWeights:
|
||||
assert out["linear.weight"].shape == (8, 4)
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Conversion Tests
|
||||
@@ -233,3 +281,27 @@ class TestSanitizeEncoderWeights:
|
||||
assert "encoder.conv1.weight" in out
|
||||
assert "conv1.weight" in out
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user