Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -14,8 +14,8 @@ class TestEndToEnd:
|
||||
|
||||
def test_tiny_model_denoise_step(self):
|
||||
"""Simulate one denoising step with tiny model."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
config = _make_tiny_config()
|
||||
@@ -43,8 +43,8 @@ class TestEndToEnd:
|
||||
|
||||
def test_tiny_model_full_loop(self):
|
||||
"""Run a complete (tiny) diffusion loop."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(123)
|
||||
config = _make_tiny_config()
|
||||
@@ -81,7 +81,7 @@ class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
def test_mask_shapes(self):
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4) # C, T, H, W
|
||||
patch_size = (1, 2, 2)
|
||||
@@ -91,7 +91,7 @@ class TestI2VMask:
|
||||
assert mask_tokens.shape == (1, 20)
|
||||
|
||||
def test_first_frame_zero(self):
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
|
||||
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
|
||||
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
|
||||
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
from mlx_video.models.wan_2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
|
||||
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
|
||||
Reference in New Issue
Block a user