Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -10,7 +10,7 @@ class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
@@ -32,13 +32,13 @@ class TestWanModelConfig:
assert config.text_len == 512
def test_head_dim_property(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
@@ -48,7 +48,7 @@ class TestWanModelConfig:
assert d["boundary"] == 0.875
def test_t5_config_values(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
@@ -69,7 +69,7 @@ class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
@@ -85,7 +85,7 @@ class TestWan21Config:
assert config.boundary == 0.0
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
@@ -98,7 +98,7 @@ class TestWan21Config:
assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
@@ -110,7 +110,7 @@ class TestWan21Config:
assert config.boundary == 0.875
def test_wan21_config_to_dict(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -119,7 +119,7 @@ class TestWan21Config:
assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
@@ -128,7 +128,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"