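Remove the Wan2 model files under the old `mlx_video.models.wan2` path, including configuration, attention mechanisms, and utility functions, and update the tests to import from `mlx_video.models.wan_2` instead (including the logger names passed to `caplog.at_level`). This cleanup streamlines the codebase, eliminates unused components, and improves maintainability by keeping the focus on the core functionality of the Wan2 module.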

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -11,7 +11,7 @@ import mlx.core as mx
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
def test_text_embedding_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
assert "time_embedding_1.weight" in out
def test_time_projection_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
assert "time_projection.bias" in out
def test_ffn_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
@@ -256,7 +256,7 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text