Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -14,8 +14,8 @@ class TestEndToEnd:
|
||||
|
||||
def test_tiny_model_denoise_step(self):
|
||||
"""Simulate one denoising step with tiny model."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(42)
|
||||
config = _make_tiny_config()
|
||||
@@ -43,8 +43,8 @@ class TestEndToEnd:
|
||||
|
||||
def test_tiny_model_full_loop(self):
|
||||
"""Run a complete (tiny) diffusion loop."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(123)
|
||||
config = _make_tiny_config()
|
||||
@@ -81,7 +81,7 @@ class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
def test_mask_shapes(self):
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4) # C, T, H, W
|
||||
patch_size = (1, 2, 2)
|
||||
@@ -91,7 +91,7 @@ class TestI2VMask:
|
||||
assert mask_tokens.shape == (1, 20)
|
||||
|
||||
def test_first_frame_zero(self):
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (48, 5, 4, 4)
|
||||
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
|
||||
@@ -111,7 +111,7 @@ class TestI2VMaskAlignment:
|
||||
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
@@ -132,7 +132,7 @@ class TestI2VMaskAlignment:
|
||||
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
from mlx_video.models.wan2.generate import _build_i2v_mask
|
||||
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
|
||||
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -235,7 +235,7 @@ class TestDimensionAlignment:
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
|
||||
Reference in New Issue
Block a user