Remove the old `wan2` model files — including configuration, attention mechanisms, and utility functions — to streamline the codebase and eliminate unused components, and update imports to the `wan_2` package path (e.g. `mlx_video.models.wan_2.convert`). This cleanup improves maintainability and keeps the focus on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
@@ -23,7 +23,7 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_cross_attention_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
@@ -31,14 +31,14 @@ class TestQuantizePredicate:
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_ffn_layers(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in [
@@ -49,13 +49,13 @@ class TestQuantizePredicate:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
@@ -63,7 +63,7 @@ class TestQuantizePredicate:
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
@@ -90,8 +90,8 @@ class TestQuantizePredicate:
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
nn.quantize(
@@ -116,7 +116,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan2.utils import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -136,7 +136,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan2.utils import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -151,7 +151,7 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan2.utils import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path,
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
def test_loading_without_quantization_flag(self, tmp_path):
"""Loading a non-quantized model should have standard Linear layers."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -172,7 +172,7 @@ class TestQuantizeRoundTrip:
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan2.utils import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
@@ -187,8 +187,8 @@ class TestQuantizeRoundTrip:
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
nn.quantize(
@@ -238,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.models.wan2.convert import _quantize_predicate
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
mx.random.seed(42)
@@ -271,8 +271,8 @@ class TestQuantizedInference:
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -295,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -316,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.models.wan2.convert import _quantize_saved_model
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()