Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -10,7 +10,7 @@ class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
@@ -32,13 +32,13 @@ class TestWanModelConfig:
assert config.text_len == 512
def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
@@ -48,7 +48,7 @@ class TestWanModelConfig:
assert d["boundary"] == 0.875
def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
@@ -69,7 +69,7 @@ class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
@@ -85,7 +85,7 @@ class TestWan21Config:
assert config.boundary == 0.0
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
@@ -98,7 +98,7 @@ class TestWan21Config:
assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
@@ -110,7 +110,7 @@ class TestWan21Config:
assert config.boundary == 0.875
def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
@@ -119,7 +119,7 @@ class TestWan21Config:
assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
@@ -128,7 +128,7 @@ class TestWan21Config:
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"