Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -11,7 +11,7 @@ import mlx.core as mx
|
||||
|
||||
class TestSanitizeTransformerWeights:
|
||||
def test_patch_embedding_reshape(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
@@ -23,7 +23,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
|
||||
|
||||
def test_text_embedding_rename(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
@@ -38,7 +38,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "text_embedding_1.bias" in out
|
||||
|
||||
def test_time_embedding_rename(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
@@ -49,7 +49,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "time_embedding_1.weight" in out
|
||||
|
||||
def test_time_projection_rename(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
@@ -60,7 +60,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "time_projection.bias" in out
|
||||
|
||||
def test_ffn_rename(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
@@ -75,7 +75,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "blocks.0.ffn.fc2.bias" in out
|
||||
|
||||
def test_freqs_skipped(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
@@ -86,7 +86,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert "blocks.0.norm1.weight" in out
|
||||
|
||||
def test_passthrough_keys(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
@@ -102,7 +102,7 @@ class TestSanitizeTransformerWeights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_transformer_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
|
||||
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
@@ -119,14 +119,14 @@ class TestSanitizeTransformerWeights:
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
@@ -139,7 +139,7 @@ class TestSanitizeT5Weights:
|
||||
assert "blocks.0.ffn.fc2.weight" in out
|
||||
|
||||
def test_passthrough(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
@@ -151,7 +151,7 @@ class TestSanitizeT5Weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_t5_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
|
||||
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
@@ -160,14 +160,14 @@ class TestSanitizeT5Weights:
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
|
||||
@@ -176,7 +176,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
|
||||
@@ -185,7 +185,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
|
||||
|
||||
def test_non_conv_passthrough(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
|
||||
@@ -196,7 +196,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["decoder.bias"].shape == (16,)
|
||||
|
||||
def test_mixed_weights(self):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
|
||||
@@ -211,7 +211,7 @@ class TestSanitizeVAEWeights:
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan2.convert import sanitize_wan_vae_weights
|
||||
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
@@ -219,7 +219,7 @@ class TestSanitizeVAEWeights:
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.convert"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
@@ -256,7 +256,7 @@ class TestWan21Convert:
|
||||
|
||||
def test_wan21_config_saved_correctly(self):
|
||||
"""Verify config dict has correct fields for Wan2.1."""
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
@@ -275,7 +275,7 @@ class TestSanitizeEncoderWeights:
|
||||
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
|
||||
|
||||
def test_exclude_encoder_by_default(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
@@ -287,7 +287,7 @@ class TestSanitizeEncoderWeights:
|
||||
assert not any("encoder" in k or k.startswith("conv1") for k in out)
|
||||
|
||||
def test_include_encoder(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
@@ -300,25 +300,25 @@ class TestSanitizeEncoderWeights:
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan2.vae22"):
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user