Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def _get_model_freqs(self, dim=64, num_heads=4):
|
||||
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
config = WanModelConfig()
|
||||
config.dim = dim
|
||||
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_three_call_vs_single_call_differ(self):
|
||||
"""Three separate rope_params calls must differ from single call."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
|
||||
This verifies each axis gets its own independent frequency range
|
||||
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
|
||||
axis should be 1.0 (theta^0). A single-call approach would give height
|
||||
starting at ~0.04 and width at ~0.002 instead of 1.0.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_model_freqs_match_manual_construction(self):
|
||||
"""WanModel.freqs should match manually constructed three-call freqs."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
"""Verify freq dimensions for 14B-scale head_dim=128."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
|
||||
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
from mlx_video.models.wan.rope import (
|
||||
from mlx_video.models.wan2.rope import (
|
||||
rope_apply,
|
||||
rope_params,
|
||||
rope_precompute_cos_sin,
|
||||
|
||||
Reference in New Issue
Block a user