Refactor Wan model structure by renaming and relocating model imports from model.py to wan2.py, enhancing code organization and clarity across the Wan2 module.
This commit is contained in:
@@ -91,7 +91,7 @@ class TestQuantizeRoundTrip:
|
||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||
"""Helper: create model, quantize, save to tmp_path."""
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
|
||||
|
||||
def test_loading_without_quantization_flag(self, tmp_path):
|
||||
"""Loading a non-quantized model should have standard Linear layers."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -188,7 +188,7 @@ class TestQuantizeRoundTrip:
|
||||
class TestQuantizedInference:
|
||||
def _make_quantized_model(self, config, bits=4):
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
nn.quantize(
|
||||
@@ -239,7 +239,7 @@ class TestQuantizedInference:
|
||||
def test_quantized_output_differs_from_unquantized(self):
|
||||
"""Sanity check: quantization should change the weights."""
|
||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
mx.random.seed(42)
|
||||
@@ -272,7 +272,7 @@ class TestQuantizationConfig:
|
||||
def test_config_metadata_written(self, tmp_path):
|
||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -296,7 +296,7 @@ class TestQuantizationConfig:
|
||||
|
||||
def test_config_metadata_8bit(self, tmp_path):
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -317,7 +317,7 @@ class TestQuantizationConfig:
|
||||
def test_dual_model_quantization(self, tmp_path):
|
||||
"""Verify dual-model quantization writes both model files."""
|
||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user