Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -14,8 +14,8 @@ class TestEndToEnd:
def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(42)
config = _make_tiny_config()
@@ -43,8 +43,8 @@ class TestEndToEnd:
def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(123)
config = _make_tiny_config()
@@ -81,7 +81,7 @@ class TestI2VMask:
"""Tests for _build_i2v_mask."""
def test_mask_shapes(self):
from mlx_video.models.wan2.generate import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2)
@@ -91,7 +91,7 @@ class TestI2VMask:
assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self):
from mlx_video.models.wan2.generate import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.models.wan2.generate import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.models.wan2.generate import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]