Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -10,7 +10,7 @@ class TestWanModelConfig:
|
||||
"""Tests for WanModelConfig dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.dim == 5120
|
||||
@@ -32,13 +32,13 @@ class TestWanModelConfig:
|
||||
assert config.text_len == 512
|
||||
|
||||
def test_head_dim_property(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.head_dim == 128 # 5120 // 40
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
d = config.to_dict()
|
||||
@@ -48,7 +48,7 @@ class TestWanModelConfig:
|
||||
assert d["boundary"] == 0.875
|
||||
|
||||
def test_t5_config_values(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.t5_vocab_size == 256384
|
||||
@@ -69,7 +69,7 @@ class TestWan21Config:
|
||||
"""Tests for Wan2.1 config presets."""
|
||||
|
||||
def test_wan21_14b_factory(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
assert config.model_version == "2.1"
|
||||
@@ -85,7 +85,7 @@ class TestWan21Config:
|
||||
assert config.boundary == 0.0
|
||||
|
||||
def test_wan21_1_3b_factory(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
assert config.model_version == "2.1"
|
||||
@@ -98,7 +98,7 @@ class TestWan21Config:
|
||||
assert config.sample_guide_scale == 5.0
|
||||
|
||||
def test_wan22_14b_factory(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
assert config.model_version == "2.2"
|
||||
@@ -110,7 +110,7 @@ class TestWan21Config:
|
||||
assert config.boundary == 0.875
|
||||
|
||||
def test_wan21_config_to_dict(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
@@ -119,7 +119,7 @@ class TestWan21Config:
|
||||
assert d["sample_guide_scale"] == 5.0
|
||||
|
||||
def test_wan21_1_3b_config_to_dict(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
d = config.to_dict()
|
||||
@@ -128,7 +128,7 @@ class TestWan21Config:
|
||||
|
||||
def test_default_config_is_wan22(self):
|
||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.model_version == "2.2"
|
||||
|
||||
Reference in New Issue
Block a user