Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -23,7 +23,7 @@ class TestI2VConfig:
|
||||
"""Test I2V-14B config preset."""
|
||||
|
||||
def test_wan22_i2v_14b_preset(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
assert config.model_type == "i2v"
|
||||
@@ -39,7 +39,7 @@ class TestI2VConfig:
|
||||
assert config.vae_z_dim == 16
|
||||
|
||||
def test_i2v_vs_t2v_differences(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
i2v = WanModelConfig.wan22_i2v_14b()
|
||||
t2v = WanModelConfig.wan22_t2v_14b()
|
||||
@@ -51,7 +51,7 @@ class TestI2VConfig:
|
||||
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
|
||||
|
||||
def test_i2v_serialization_roundtrip(self):
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
d = config.to_dict()
|
||||
@@ -66,7 +66,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_without_y(self):
|
||||
"""Standard T2V forward pass (no y) still works."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -85,7 +85,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_with_y(self):
|
||||
"""I2V forward pass with y channel concatenation."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -108,7 +108,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_y_none_is_noop(self):
|
||||
"""Passing y=None should be identical to not passing y."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -129,7 +129,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_batched_cfg_with_y(self):
|
||||
"""Batched CFG (B=2) with y should work."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -158,7 +158,7 @@ class TestVAEEncoder:
|
||||
"""Test Wan2.1 VAE encoder."""
|
||||
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
from mlx_video.models.wan2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(
|
||||
dim=32, z_dim=8
|
||||
@@ -169,7 +169,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_encoder3d_output_shape(self):
|
||||
"""Encoder should downsample spatially by 8x and temporally by 4x."""
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
from mlx_video.models.wan2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8)
|
||||
# Random input: [B=1, 3, T=5, H=32, W=32]
|
||||
@@ -186,7 +186,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_wan_vae_encode(self):
|
||||
"""WanVAE with encoder=True should produce normalized latents."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
# Input: [B=1, 3, T=5, H=32, W=32]
|
||||
@@ -198,7 +198,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_wan_vae_encoder_flag(self):
|
||||
"""WanVAE without encoder flag should not have encoder attribute."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, "encoder")
|
||||
@@ -211,7 +211,7 @@ class TestResampleDownsample:
|
||||
"""Test downsample modes in Resample."""
|
||||
|
||||
def test_downsample2d(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 8, 8))
|
||||
@@ -221,7 +221,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_downsample3d(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample3d")
|
||||
x = mx.random.normal((1, 16, 4, 8, 8))
|
||||
@@ -231,7 +231,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_upsample2d_still_works(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
@@ -240,7 +240,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 8, 2, 8, 8)
|
||||
|
||||
def test_upsample3d_still_works(self):
|
||||
from mlx_video.models.wan.vae import Resample
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample3d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
mx.random.seed(0)
|
||||
|
||||
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
config = _make_tiny_i2v_config()
|
||||
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
config = _make_tiny_i2v_config()
|
||||
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
config = _make_tiny_config()
|
||||
|
||||
Reference in New Issue
Block a user