Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.model import WanModel
config = WanModelConfig()
config.dim = dim
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128 # head_dim for all Wan models
# Reference: three separate calls
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
d_h_dim = 2 * (d // 6) # 42
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0.
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan2.rope import rope_params
d = 128
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_apply, rope_params
from mlx_video.models.wan2.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate(
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan.rope import (
from mlx_video.models.wan2.rope import (
rope_apply,
rope_params,
rope_precompute_cos_sin,