format
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
"""Tests for Wan model configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModelConfig:
|
||||
"""Tests for WanModelConfig dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.dim == 5120
|
||||
assert config.ffn_dim == 13824
|
||||
@@ -33,11 +33,13 @@ class TestWanModelConfig:
|
||||
|
||||
def test_head_dim_property(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.head_dim == 128 # 5120 // 40
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
d = config.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
@@ -47,6 +49,7 @@ class TestWanModelConfig:
|
||||
|
||||
def test_t5_config_values(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.t5_vocab_size == 256384
|
||||
assert config.t5_dim == 4096
|
||||
@@ -61,11 +64,13 @@ class TestWanModelConfig:
|
||||
# Wan2.1 Config Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Config:
|
||||
"""Tests for Wan2.1 config presets."""
|
||||
|
||||
def test_wan21_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
@@ -81,6 +86,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_1_3b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
assert config.model_version == "2.1"
|
||||
assert config.dual_model is False
|
||||
@@ -93,6 +99,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan22_14b_factory(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
@@ -104,6 +111,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
d = config.to_dict()
|
||||
assert d["model_version"] == "2.1"
|
||||
@@ -112,6 +120,7 @@ class TestWan21Config:
|
||||
|
||||
def test_wan21_1_3b_config_to_dict(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
d = config.to_dict()
|
||||
assert d["dim"] == 1536
|
||||
@@ -120,6 +129,7 @@ class TestWan21Config:
|
||||
def test_default_config_is_wan22(self):
|
||||
"""Default WanModelConfig() should be Wan2.2 14B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig()
|
||||
assert config.model_version == "2.2"
|
||||
assert config.dual_model is True
|
||||
|
||||
Reference in New Issue
Block a user