Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -15,7 +15,7 @@ from wan_test_helpers import _make_tiny_config
|
||||
|
||||
class TestQuantizePredicate:
|
||||
def test_matches_self_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
@@ -23,7 +23,7 @@ class TestQuantizePredicate:
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_cross_attention_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for suffix in ["q", "k", "v", "o"]:
|
||||
@@ -31,14 +31,14 @@ class TestQuantizePredicate:
|
||||
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
|
||||
|
||||
def test_matches_ffn_layers(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
|
||||
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
|
||||
|
||||
def test_rejects_embeddings(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
for path in [
|
||||
@@ -49,13 +49,13 @@ class TestQuantizePredicate:
|
||||
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
|
||||
|
||||
def test_rejects_norms(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
|
||||
|
||||
def test_rejects_non_quantizable_modules(self):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_norm = nn.RMSNorm(64)
|
||||
# Even if path matches, module must have to_quantized
|
||||
@@ -63,7 +63,7 @@ class TestQuantizePredicate:
|
||||
|
||||
def test_all_10_patterns_covered(self):
|
||||
"""Verify exactly 10 layer patterns are targeted."""
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
|
||||
mock_linear = nn.Linear(64, 64)
|
||||
patterns = [
|
||||
@@ -90,8 +90,8 @@ class TestQuantizePredicate:
|
||||
class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -116,7 +116,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -136,7 +136,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -151,7 +151,7 @@ class TestQuantizeRoundTrip:
|
||||
config = _make_tiny_config()
|
||||
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(
|
||||
model_path,
|
||||
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
|
||||
|
||||
def test_loading_without_quantization_flag(self, tmp_path):
|
||||
"""Loading a non-quantized model should have standard Linear layers."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -172,7 +172,7 @@ class TestQuantizeRoundTrip:
|
||||
model_path = tmp_path / "model.safetensors"
|
||||
mx.save_safetensors(str(model_path), weights_dict)
|
||||
|
||||
from mlx_video.models.wan.loading import load_wan_model
|
||||
from mlx_video.models.wan2.utils import load_wan_model
|
||||
|
||||
loaded = load_wan_model(model_path, config, quantization=None)
|
||||
|
||||
@@ -187,8 +187,8 @@ class TestQuantizeRoundTrip:
|
||||
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -238,8 +238,8 @@ class TestQuantizedInference:
|
||||
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
@@ -271,8 +271,8 @@ class TestQuantizedInference:
|
||||
class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -295,8 +295,8 @@ class TestQuantizationConfig:
|
||||
assert cfg["quantization"]["group_size"] == 64
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -316,8 +316,8 @@ class TestQuantizationConfig:
|
||||
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.convert_wan import _quantize_saved_model
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user