Refactor Wan model structure by renaming and relocating model imports from model.py to wan2.py, enhancing code organization and clarity across the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:57:29 +01:00
parent 6c63163671
commit b029668cd2
13 changed files with 45 additions and 45 deletions

View File

@@ -66,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -85,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -108,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -129,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -307,7 +307,7 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan2.vae import WanVAE
@@ -410,7 +410,7 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
@@ -485,7 +485,7 @@ class TestDualModelSwitching:
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
@@ -545,7 +545,7 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan2.model import WanModel
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)