Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -23,7 +23,7 @@ class TestI2VConfig:
|
||||
"""Test I2V-14B config preset."""
|
||||
|
||||
def test_wan22_i2v_14b_preset(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
assert config.model_type == "i2v"
|
||||
@@ -39,7 +39,7 @@ class TestI2VConfig:
|
||||
assert config.vae_z_dim == 16
|
||||
|
||||
def test_i2v_vs_t2v_differences(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
i2v = WanModelConfig.wan22_i2v_14b()
|
||||
t2v = WanModelConfig.wan22_t2v_14b()
|
||||
@@ -51,7 +51,7 @@ class TestI2VConfig:
|
||||
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
|
||||
|
||||
def test_i2v_serialization_roundtrip(self):
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_i2v_14b()
|
||||
d = config.to_dict()
|
||||
@@ -66,7 +66,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_without_y(self):
|
||||
"""Standard T2V forward pass (no y) still works."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -85,7 +85,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_with_y(self):
|
||||
"""I2V forward pass with y channel concatenation."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -108,7 +108,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_y_none_is_noop(self):
|
||||
"""Passing y=None should be identical to not passing y."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -129,7 +129,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_batched_cfg_with_y(self):
|
||||
"""Batched CFG (B=2) with y should work."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -158,7 +158,7 @@ class TestVAEEncoder:
|
||||
"""Test Wan2.1 VAE encoder."""
|
||||
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan2.vae import Encoder3d
|
||||
from mlx_video.models.wan_2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(
|
||||
dim=32, z_dim=8
|
||||
@@ -169,7 +169,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_encoder3d_output_shape(self):
|
||||
"""Encoder should downsample spatially by 8x and temporally by 4x."""
|
||||
from mlx_video.models.wan2.vae import Encoder3d
|
||||
from mlx_video.models.wan_2.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8)
|
||||
# Random input: [B=1, 3, T=5, H=32, W=32]
|
||||
@@ -186,7 +186,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_wan_vae_encode(self):
|
||||
"""WanVAE with encoder=True should produce normalized latents."""
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16, encoder=True)
|
||||
# Input: [B=1, 3, T=5, H=32, W=32]
|
||||
@@ -198,7 +198,7 @@ class TestVAEEncoder:
|
||||
|
||||
def test_wan_vae_encoder_flag(self):
|
||||
"""WanVAE without encoder flag should not have encoder attribute."""
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, "encoder")
|
||||
@@ -211,7 +211,7 @@ class TestResampleDownsample:
|
||||
"""Test downsample modes in Resample."""
|
||||
|
||||
def test_downsample2d(self):
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 8, 8))
|
||||
@@ -221,7 +221,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_downsample3d(self):
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="downsample3d")
|
||||
x = mx.random.normal((1, 16, 4, 8, 8))
|
||||
@@ -231,7 +231,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_upsample2d_still_works(self):
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample2d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
@@ -240,7 +240,7 @@ class TestResampleDownsample:
|
||||
assert out.shape == (1, 8, 2, 8, 8)
|
||||
|
||||
def test_upsample3d_still_works(self):
|
||||
from mlx_video.models.wan2.vae import Resample
|
||||
from mlx_video.models.wan_2.vae import Resample
|
||||
|
||||
r = Resample(dim=16, mode="upsample3d")
|
||||
x = mx.random.normal((1, 16, 2, 4, 4))
|
||||
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
mx.random.seed(0)
|
||||
|
||||
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
config = _make_tiny_i2v_config()
|
||||
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
config = _make_tiny_i2v_config()
|
||||
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
config = _make_tiny_config()
|
||||
|
||||
Reference in New Issue
Block a user