Refactor Wan model structure by renaming and relocating model imports from model.py to wan2.py, enhancing code organization and clarity across the Wan2 module.
This commit is contained in:
@@ -66,7 +66,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_without_y(self):
|
||||
"""Standard T2V forward pass (no y) still works."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -85,7 +85,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_forward_with_y(self):
|
||||
"""I2V forward pass with y channel concatenation."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -108,7 +108,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_y_none_is_noop(self):
|
||||
"""Passing y=None should be identical to not passing y."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
@@ -129,7 +129,7 @@ class TestModelYParameter:
|
||||
|
||||
def test_batched_cfg_with_y(self):
|
||||
"""Batched CFG (B=2) with y should work."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
|
||||
config = _make_tiny_i2v_config()
|
||||
model = WanModel(config)
|
||||
@@ -307,7 +307,7 @@ class TestI2VEndToEndPipeline:
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
@@ -410,7 +410,7 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
@@ -485,7 +485,7 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
@@ -545,7 +545,7 @@ class TestDualModelSwitching:
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
|
||||
Reference in New Issue
Block a user