Refactor Wan model imports and update script paths in pyproject.toml; migrate from the `wan` module structure to the `wan2` module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -14,8 +14,8 @@ class TestEndToEnd:
def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(42)
config = _make_tiny_config()
@@ -43,8 +43,8 @@ class TestEndToEnd:
def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(123)
config = _make_tiny_config()
@@ -81,7 +81,7 @@ class TestI2VMask:
"""Tests for _build_i2v_mask."""
def test_mask_shapes(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2)
@@ -91,7 +91,7 @@ class TestI2VMask:
assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan2.generate import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan2.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]