"""Tests for Wan model configuration."""


# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------


class TestWanModelConfig:
    """Tests for the default values and helpers of the WanModelConfig dataclass.

    ``WanModelConfig`` is imported inside each test instead of at module
    level, so an import problem in ``mlx_video`` fails the individual tests
    rather than the collection of this whole module.
    """

    @staticmethod
    def _assert_attrs(obj, expected):
        """Assert ``getattr(obj, name) == value`` for every pair in *expected*.

        Pairs are checked in insertion order; the first mismatch fails the
        calling test with a message naming the offending attribute.
        """
        for name, want in expected.items():
            got = getattr(obj, name)
            assert got == want, f"{name}: expected {want!r}, got {got!r}"

    def test_default_values(self):
        from mlx_video.models.wan2.config import WanModelConfig

        cfg = WanModelConfig()
        self._assert_attrs(
            cfg,
            {
                "dim": 5120,
                "ffn_dim": 13824,
                "num_heads": 40,
                "num_layers": 40,
                "in_dim": 16,
                "out_dim": 16,
                "patch_size": (1, 2, 2),
                "vae_stride": (4, 8, 8),
                "vae_z_dim": 16,
                "boundary": 0.875,
                "sample_shift": 12.0,
                "sample_steps": 40,
                "sample_guide_scale": (3.0, 4.0),
                "num_train_timesteps": 1000,
                "text_len": 512,
            },
        )
        # Identity checks: these flags must be the bool True, not merely truthy.
        assert cfg.qk_norm is True
        assert cfg.cross_attn_norm is True

    def test_head_dim_property(self):
        from mlx_video.models.wan2.config import WanModelConfig

        # The default dim / num_heads is 5120 / 40, i.e. 128 per head.
        assert WanModelConfig().head_dim == 128

    def test_to_dict_roundtrip(self):
        # NOTE(review): despite its name, this test only inspects the dict
        # returned by to_dict(); it never rebuilds a config from that dict.
        from mlx_video.models.wan2.config import WanModelConfig

        serialized = WanModelConfig().to_dict()
        assert isinstance(serialized, dict)
        for key, want in (
            ("dim", 5120),
            ("patch_size", (1, 2, 2)),
            ("boundary", 0.875),
        ):
            assert serialized[key] == want, key

    def test_t5_config_values(self):
        from mlx_video.models.wan2.config import WanModelConfig

        self._assert_attrs(
            WanModelConfig(),
            {
                "t5_vocab_size": 256384,
                "t5_dim": 4096,
                "t5_dim_attn": 4096,
                "t5_dim_ffn": 10240,
                "t5_num_heads": 64,
                "t5_num_layers": 24,
                "t5_num_buckets": 32,
            },
        )


# ---------------------------------------------------------------------------
# Wan2.1 / Wan2.2 Preset Config Tests
# ---------------------------------------------------------------------------


class TestWan21Config:
    """Tests for the Wan2.1 and Wan2.2 factory presets of WanModelConfig.

    Also checks which variant a bare ``WanModelConfig()`` reports.
    """

    @staticmethod
    def _assert_attrs(obj, expected):
        """Assert ``getattr(obj, name) == value`` for every pair in *expected*.

        Pairs are checked in insertion order; the first mismatch fails the
        calling test with a message naming the offending attribute.
        """
        for name, want in expected.items():
            got = getattr(obj, name)
            assert got == want, f"{name}: expected {want!r}, got {got!r}"

    def test_wan21_14b_factory(self):
        from mlx_video.models.wan2.config import WanModelConfig

        cfg = WanModelConfig.wan21_t2v_14b()
        assert cfg.model_version == "2.1"
        # Identity check: the flag must be the bool False, not merely falsy.
        assert cfg.dual_model is False
        self._assert_attrs(
            cfg,
            {
                "dim": 5120,
                "ffn_dim": 13824,
                "num_heads": 40,
                "num_layers": 40,
                "head_dim": 128,
                # A single scalar here, unlike the 2-tuple of the Wan2.2 preset.
                "sample_guide_scale": 5.0,
                "sample_shift": 5.0,
                "sample_steps": 50,
                "boundary": 0.0,
            },
        )

    def test_wan21_1_3b_factory(self):
        from mlx_video.models.wan2.config import WanModelConfig

        cfg = WanModelConfig.wan21_t2v_1_3b()
        assert cfg.model_version == "2.1"
        assert cfg.dual_model is False
        self._assert_attrs(
            cfg,
            {
                "dim": 1536,
                "ffn_dim": 8960,
                "num_heads": 12,
                "num_layers": 30,
                "head_dim": 128,  # 1536 // 12
                "sample_guide_scale": 5.0,
            },
        )

    def test_wan22_14b_factory(self):
        from mlx_video.models.wan2.config import WanModelConfig

        cfg = WanModelConfig.wan22_t2v_14b()
        assert cfg.model_version == "2.2"
        assert cfg.dual_model is True
        self._assert_attrs(
            cfg,
            {
                "dim": 5120,
                "sample_guide_scale": (3.0, 4.0),
                "sample_shift": 12.0,
                "sample_steps": 40,
                "boundary": 0.875,
            },
        )

    def test_wan21_config_to_dict(self):
        from mlx_video.models.wan2.config import WanModelConfig

        serialized = WanModelConfig.wan21_t2v_14b().to_dict()
        assert serialized["model_version"] == "2.1"
        assert serialized["dual_model"] is False
        assert serialized["sample_guide_scale"] == 5.0

    def test_wan21_1_3b_config_to_dict(self):
        from mlx_video.models.wan2.config import WanModelConfig

        serialized = WanModelConfig.wan21_t2v_1_3b().to_dict()
        for key, want in (("dim", 1536), ("num_layers", 30)):
            assert serialized[key] == want, key

    def test_default_config_is_wan22(self):
        """A bare ``WanModelConfig()`` reports the Wan2.2 dual-model variant."""
        from mlx_video.models.wan2.config import WanModelConfig

        cfg = WanModelConfig()
        assert cfg.model_version == "2.2"
        assert cfg.dual_model is True