325 lines
12 KiB
Python
325 lines
12 KiB
Python
"""Tests for Wan weight conversion utilities."""
|
|
|
|
import logging
|
|
|
|
import mlx.core as mx
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Transformer Weight Conversion Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSanitizeTransformerWeights:
|
|
def test_patch_embedding_reshape(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
|
"patch_embedding.bias": mx.random.normal((5120,)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "patch_embedding_proj.weight" in out
|
|
assert "patch_embedding_proj.bias" in out
|
|
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
|
|
|
|
def test_text_embedding_rename(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"text_embedding.0.weight": mx.zeros((64, 32)),
|
|
"text_embedding.0.bias": mx.zeros((64,)),
|
|
"text_embedding.2.weight": mx.zeros((64, 64)),
|
|
"text_embedding.2.bias": mx.zeros((64,)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "text_embedding_0.weight" in out
|
|
assert "text_embedding_0.bias" in out
|
|
assert "text_embedding_1.weight" in out
|
|
assert "text_embedding_1.bias" in out
|
|
|
|
def test_time_embedding_rename(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"time_embedding.0.weight": mx.zeros((64, 32)),
|
|
"time_embedding.2.weight": mx.zeros((64, 64)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "time_embedding_0.weight" in out
|
|
assert "time_embedding_1.weight" in out
|
|
|
|
def test_time_projection_rename(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"time_projection.1.weight": mx.zeros((384, 64)),
|
|
"time_projection.1.bias": mx.zeros((384,)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "time_projection.weight" in out
|
|
assert "time_projection.bias" in out
|
|
|
|
def test_ffn_rename(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.0.bias": mx.zeros((128,)),
|
|
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
|
"blocks.0.ffn.2.bias": mx.zeros((64,)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "blocks.0.ffn.fc1.weight" in out
|
|
assert "blocks.0.ffn.fc1.bias" in out
|
|
assert "blocks.0.ffn.fc2.weight" in out
|
|
assert "blocks.0.ffn.fc2.bias" in out
|
|
|
|
def test_freqs_skipped(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"freqs": mx.zeros((1024, 64, 2)),
|
|
"blocks.0.norm1.weight": mx.zeros((64,)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
assert "freqs" not in out
|
|
assert "blocks.0.norm1.weight" in out
|
|
|
|
def test_passthrough_keys(self):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
|
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
|
|
"blocks.0.self_attn.v.weight": mx.zeros((64, 64)),
|
|
"blocks.0.self_attn.o.weight": mx.zeros((64, 64)),
|
|
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
|
"head.head.weight": mx.zeros((64, 64)),
|
|
"head.modulation": mx.zeros((1, 2, 64)),
|
|
}
|
|
out = sanitize_wan_transformer_weights(weights)
|
|
for key in weights:
|
|
assert key in out
|
|
|
|
def test_no_unconsumed_keys(self, caplog):
|
|
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
|
|
|
weights = {
|
|
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
|
"patch_embedding.bias": mx.random.normal((5120,)),
|
|
"text_embedding.0.weight": mx.zeros((64, 32)),
|
|
"text_embedding.2.weight": mx.zeros((64, 64)),
|
|
"time_embedding.0.weight": mx.zeros((64, 32)),
|
|
"time_embedding.2.weight": mx.zeros((64, 64)),
|
|
"time_projection.1.weight": mx.zeros((384, 64)),
|
|
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
|
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
|
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
|
"head.head.weight": mx.zeros((64, 64)),
|
|
"freqs": mx.zeros((1024, 64, 2)),
|
|
}
|
|
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
|
sanitize_wan_transformer_weights(weights)
|
|
assert "Unconsumed" not in caplog.text
|
|
|
|
|
|
class TestSanitizeT5Weights:
|
|
def test_gate_rename(self):
|
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
|
|
|
weights = {
|
|
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
|
}
|
|
out = sanitize_wan_t5_weights(weights)
|
|
assert "blocks.0.ffn.gate_proj.weight" in out
|
|
assert "blocks.0.ffn.fc1.weight" in out
|
|
assert "blocks.0.ffn.fc2.weight" in out
|
|
|
|
def test_passthrough(self):
|
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
|
|
|
weights = {
|
|
"token_embedding.weight": mx.zeros((100, 64)),
|
|
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
|
|
"norm.weight": mx.zeros((64,)),
|
|
}
|
|
out = sanitize_wan_t5_weights(weights)
|
|
for key in weights:
|
|
assert key in out
|
|
|
|
def test_no_unconsumed_keys(self, caplog):
|
|
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
|
|
|
weights = {
|
|
"token_embedding.weight": mx.zeros((100, 64)),
|
|
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
|
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
|
"norm.weight": mx.zeros((64,)),
|
|
}
|
|
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
|
sanitize_wan_t5_weights(weights)
|
|
assert "Unconsumed" not in caplog.text
|
|
|
|
|
|
class TestSanitizeVAEWeights:
|
|
def test_conv3d_transpose(self):
|
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
|
|
|
weights = {
|
|
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
|
}
|
|
out = sanitize_wan_vae_weights(weights)
|
|
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
|
|
|
|
def test_conv2d_transpose(self):
|
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
|
|
|
weights = {
|
|
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
|
}
|
|
out = sanitize_wan_vae_weights(weights)
|
|
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
|
|
|
|
def test_non_conv_passthrough(self):
|
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
|
|
|
weights = {
|
|
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
|
"decoder.bias": mx.zeros((16,)),
|
|
}
|
|
out = sanitize_wan_vae_weights(weights)
|
|
assert out["decoder.norm.weight"].shape == (64,)
|
|
assert out["decoder.bias"].shape == (16,)
|
|
|
|
def test_mixed_weights(self):
|
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
|
|
|
weights = {
|
|
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
|
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
|
|
"linear.weight": mx.zeros((8, 4)), # 2D
|
|
"norm.weight": mx.zeros((8,)), # 1D
|
|
}
|
|
out = sanitize_wan_vae_weights(weights)
|
|
assert out["conv3d.weight"].shape == (8, 3, 3, 3, 4)
|
|
assert out["conv2d.weight"].shape == (8, 3, 3, 4)
|
|
assert out["linear.weight"].shape == (8, 4)
|
|
assert out["norm.weight"].shape == (8,)
|
|
|
|
def test_no_unconsumed_keys(self, caplog):
|
|
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
|
|
|
weights = {
|
|
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
|
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
|
"decoder.norm.weight": mx.zeros((64,)),
|
|
"decoder.bias": mx.zeros((16,)),
|
|
}
|
|
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
|
sanitize_wan_vae_weights(weights)
|
|
assert "Unconsumed" not in caplog.text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Wan2.1 Conversion Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestWan21Convert:
|
|
"""Tests for Wan2.1 conversion support."""
|
|
|
|
def test_auto_detect_wan21(self, tmp_path):
|
|
"""Auto-detect single-model directory as Wan2.1."""
|
|
# Create a Wan2.1-style directory (no low_noise_model subdir)
|
|
(tmp_path / "dummy.safetensors").touch()
|
|
# The auto-detect logic: no low_noise_model dir → 2.1
|
|
|
|
low = tmp_path / "low_noise_model"
|
|
assert not low.exists()
|
|
# Simulates auto detection
|
|
version = "2.2" if low.exists() else "2.1"
|
|
assert version == "2.1"
|
|
|
|
def test_auto_detect_wan22(self, tmp_path):
|
|
"""Auto-detect dual-model directory as Wan2.2."""
|
|
(tmp_path / "low_noise_model").mkdir()
|
|
(tmp_path / "high_noise_model").mkdir()
|
|
|
|
low = tmp_path / "low_noise_model"
|
|
assert low.exists()
|
|
version = "2.2" if low.exists() else "2.1"
|
|
assert version == "2.2"
|
|
|
|
def test_wan21_config_saved_correctly(self):
|
|
"""Verify config dict has correct fields for Wan2.1."""
|
|
from mlx_video.models.wan.config import WanModelConfig
|
|
|
|
config = WanModelConfig.wan21_t2v_14b()
|
|
d = config.to_dict()
|
|
assert d["model_version"] == "2.1"
|
|
assert d["dual_model"] is False
|
|
assert d["sample_steps"] == 50
|
|
assert d["sample_shift"] == 5.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Encoder Weight Sanitization Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSanitizeEncoderWeights:
|
|
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
|
|
|
def test_exclude_encoder_by_default(self):
|
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
|
|
|
weights = {
|
|
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
|
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
}
|
|
out = sanitize_wan22_vae_weights(weights, include_encoder=False)
|
|
assert "conv2.weight" in out
|
|
assert not any("encoder" in k or k.startswith("conv1") for k in out)
|
|
|
|
def test_include_encoder(self):
|
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
|
|
|
weights = {
|
|
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
|
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
}
|
|
out = sanitize_wan22_vae_weights(weights, include_encoder=True)
|
|
assert "encoder.conv1.weight" in out
|
|
assert "conv1.weight" in out
|
|
assert "conv2.weight" in out
|
|
|
|
def test_no_unconsumed_keys(self, caplog):
|
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
|
|
|
weights = {
|
|
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
|
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
}
|
|
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
|
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
|
assert "Unconsumed" not in caplog.text
|
|
|
|
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
|
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
|
|
|
weights = {
|
|
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
|
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
|
}
|
|
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
|
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
|
assert "Unconsumed" not in caplog.text
|