Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config
|
||||
|
||||
class TestQuantizePredicate:
|
||||
def test_matches_self_attention_layers(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
@@ -23,7 +23,7 @@ class TestQuantizePredicate:
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_cross_attention_layers(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
@@ -31,14 +31,14 @@ class TestQuantizePredicate:
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_ffn_layers(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||
|
||||
def test_rejects_embeddings(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for path in [
|
||||
@@ -49,13 +49,13 @@ class TestQuantizePredicate:
|
||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||
|
||||
def test_rejects_norms(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||
|
||||
def test_rejects_non_quantizable_modules(self):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
# Even if path matches, module must have to_quantized
|
||||
@@ -63,7 +63,7 @@ class TestQuantizePredicate:
|
||||
|
||||
def test_all_10_patterns_covered(self):
|
||||
"""Verify exactly 10 layer patterns are targeted."""
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
patterns = [
|
||||
@@ -90,8 +90,8 @@ class TestQuantizePredicate:
|
||||
class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -116,7 +116,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -136,7 +136,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -151,7 +151,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
|
||||
|
||||
def test_loading_without_quantization_flag(self, tmp_path):
|
||||
"""Loading a non-quantized model should have standard Linear layers."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -172,7 +172,7 @@ class TestQuantizeRoundTrip:
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
from mlx_video.models.wan_2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(model_path, config, quantization=None)
|
||||
|
||||
@@ -187,8 +187,8 @@ class TestQuantizeRoundTrip:
|
||||
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -238,8 +238,8 @@ class TestQuantizedInference:
|
||||
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
@@ -271,8 +271,8 @@ class TestQuantizedInference:
|
||||
class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -295,8 +295,8 @@ class TestQuantizationConfig:
|
||||
assert cfg["quantization"]["group_size"] == 64
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -316,8 +316,8 @@ class TestQuantizationConfig:
|
||||
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user