Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -11,7 +11,7 @@ import mlx.core as mx
|
||||
|
||||
class TestSanitizeTransformerWeights:
|
||||
def test_patch_embedding_reshape(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
|
||||
|
||||
def test_text_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "text_embedding_1.bias" in out
|
||||
|
||||
def test_time_embedding_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "time_embedding_1.weight" in out
|
||||
|
||||
def test_time_projection_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "time_projection.bias" in out
|
||||
|
||||
def test_ffn_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "blocks.0.ffn.fc2.bias" in out
|
||||
|
||||
def test_freqs_skipped(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "blocks.0.norm1.weight" in out
|
||||
|
||||
def test_passthrough_keys(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
|
||||
def test_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
|
||||
|
||||
def test_non_conv_passthrough(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.bias"].shape == (16,)
|
||||
|
||||
def test_mixed_weights(self):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
@@ -256,7 +256,7 @@ class TestWan21Convert:
|
||||
|
||||
def test_wan21_config_saved_correctly(self):
|
||||
"""Verify config dict has correct fields for Wan2.1."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
|
||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||
|
||||
def test_exclude_encoder_by_default(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
|
||||
assert not any("encoder" in k or k.startswith("conv1") for k in out)
|
||||
|
||||
def test_include_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user