Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -11,7 +11,7 @@ import mlx.core as mx
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
assert "time_embedding_1.weight" in out
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
assert "time_projection.bias" in out
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
@@ -256,7 +256,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text