Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config
|
||||
|
||||
class TestSinusoidalEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.arange(10).astype(mx.float32)
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
@@ -21,7 +21,7 @@ class TestSinusoidalEmbedding:
|
||||
|
||||
def test_position_zero(self):
|
||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0])
|
||||
emb = sinusoidal_embedding_1d(64, pos)
|
||||
@@ -33,7 +33,7 @@ class TestSinusoidalEmbedding:
|
||||
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 999.0])
|
||||
emb = sinusoidal_embedding_1d(128, pos)
|
||||
@@ -50,7 +50,7 @@ class TestSinusoidalEmbedding:
|
||||
|
||||
class TestHead:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
from mlx_video.models.wan2.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
B, L = 1, 24
|
||||
@@ -62,7 +62,7 @@ class TestHead:
|
||||
assert out.shape == (B, L, expected_proj_dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
from mlx_video.models.wan2.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
assert head.modulation.shape == (1, 2, 64)
|
||||
@@ -78,7 +78,7 @@ class TestWanModel:
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -86,7 +86,7 @@ class TestWanModel:
|
||||
assert num_params > 0
|
||||
|
||||
def test_patchify_shape(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -99,7 +99,7 @@ class TestWanModel:
|
||||
assert patches.shape == (1, 1 * 2 * 2, config.dim)
|
||||
|
||||
def test_patchify_various_sizes(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -115,7 +115,7 @@ class TestWanModel:
|
||||
|
||||
def test_unpatchify_inverse(self):
|
||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -131,7 +131,7 @@ class TestWanModel:
|
||||
assert out[0].shape == (config.out_dim, F, H, W)
|
||||
|
||||
def test_forward_pass(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -149,7 +149,7 @@ class TestWanModel:
|
||||
assert out[0].shape == (C, F, H, W)
|
||||
|
||||
def test_forward_batch(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -171,7 +171,7 @@ class TestWanModel:
|
||||
assert o.shape == (C, F, H, W)
|
||||
|
||||
def test_output_is_float32(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -200,7 +200,7 @@ class TestWan21Model:
|
||||
|
||||
def _make_tiny_wan21_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
# Override to tiny values
|
||||
@@ -217,7 +217,7 @@ class TestWan21Model:
|
||||
|
||||
def _make_tiny_wan21_1_3b_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||
@@ -234,7 +234,7 @@ class TestWan21Model:
|
||||
|
||||
def test_wan21_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 tiny config."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
@@ -252,7 +252,7 @@ class TestWan21Model:
|
||||
|
||||
def test_wan21_1_3b_tiny_model_forward(self):
|
||||
"""Forward pass with Wan2.1 1.3B tiny config."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = self._make_tiny_wan21_1_3b_config()
|
||||
model = WanModel(config)
|
||||
@@ -270,8 +270,8 @@ class TestWan21Model:
|
||||
|
||||
def test_wan21_single_model_loop(self):
|
||||
"""Full diffusion loop with single model (Wan2.1 style)."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
config = self._make_tiny_wan21_config()
|
||||
model = WanModel(config)
|
||||
@@ -305,7 +305,7 @@ class TestWan21Model:
|
||||
|
||||
def test_wan21_vs_wan22_config_differences(self):
|
||||
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
c21 = WanModelConfig.wan21_t2v_14b()
|
||||
c22 = WanModelConfig.wan22_t2v_14b()
|
||||
@@ -333,21 +333,21 @@ class TestPerTokenTimestep:
|
||||
"""Tests for per-token sinusoidal embedding."""
|
||||
|
||||
def test_1d_unchanged(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 500.0])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (3, 256)
|
||||
|
||||
def test_2d_per_token(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
assert emb.shape == (2, 3, 256)
|
||||
|
||||
def test_consistency(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
||||
|
||||
pos_1d = mx.array([0.0, 100.0])
|
||||
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
||||
|
||||
Reference in New Issue
Block a user