Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -23,7 +23,7 @@ class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
@@ -39,7 +39,7 @@ class TestI2VConfig:
assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
@@ -51,7 +51,7 @@ class TestI2VConfig:
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
@@ -66,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -85,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -108,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -129,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -158,7 +158,7 @@ class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan2.vae import Encoder3d
from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d(
dim=32, z_dim=8
@@ -169,7 +169,7 @@ class TestVAEEncoder:
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan2.vae import Encoder3d
from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
@@ -186,7 +186,7 @@ class TestVAEEncoder:
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
@@ -198,7 +198,7 @@ class TestVAEEncoder:
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, "encoder")
@@ -211,7 +211,7 @@ class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan2.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
@@ -221,7 +221,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan2.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
@@ -231,7 +231,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan2.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -240,7 +240,7 @@ class TestResampleDownsample:
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan2.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -307,9 +307,9 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.vae import WanVAE
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.vae import WanVAE
mx.random.seed(0)
@@ -410,8 +410,8 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
@@ -485,8 +485,8 @@ class TestDualModelSwitching:
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
@@ -545,8 +545,8 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()