Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -23,7 +23,7 @@ class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
@@ -39,7 +39,7 @@ class TestI2VConfig:
assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
@@ -51,7 +51,7 @@ class TestI2VConfig:
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
@@ -66,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -85,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -108,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -129,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -158,7 +158,7 @@ class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan2.vae import Encoder3d
enc = Encoder3d(
dim=32, z_dim=8
@@ -169,7 +169,7 @@ class TestVAEEncoder:
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
@@ -186,7 +186,7 @@ class TestVAEEncoder:
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
@@ -198,7 +198,7 @@ class TestVAEEncoder:
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, "encoder")
@@ -211,7 +211,7 @@ class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
@@ -221,7 +221,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
@@ -231,7 +231,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -240,7 +240,7 @@ class TestResampleDownsample:
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan2.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.vae import WanVAE
mx.random.seed(0)
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()