Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def _get_model_freqs(self, dim=64, num_heads=4):
|
||||
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
from mlx_video.models.wan2.wan2 import WanModel
|
||||
from mlx_video.models.wan_2.config import WanModelConfig
|
||||
from mlx_video.models.wan_2.wan_2 import WanModel
|
||||
|
||||
config = WanModelConfig()
|
||||
config.dim = dim
|
||||
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_three_call_vs_single_call_differ(self):
|
||||
"""Three separate rope_params calls must differ from single call."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128 # head_dim for all Wan models
|
||||
# Reference: three separate calls
|
||||
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
|
||||
This verifies each axis gets its own independent frequency range
|
||||
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
|
||||
"""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
|
||||
"""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
d_h_dim = 2 * (d // 6) # 42
|
||||
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
|
||||
axis should be 1.0 (theta^0). A single-call approach would give height
|
||||
starting at ~0.04 and width at ~0.002 instead of 1.0.
|
||||
"""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_model_freqs_match_manual_construction(self):
|
||||
"""WanModel.freqs should match manually constructed three-call freqs."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
|
||||
d = head_dim # 16
|
||||
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
|
||||
|
||||
def test_model_freqs_14b_dimensions(self):
|
||||
"""Verify freq dimensions for 14B-scale head_dim=128."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
|
||||
"""Numerically compare MLX and PyTorch frequency tables."""
|
||||
import torch
|
||||
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
|
||||
d = 128
|
||||
|
||||
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
This is the key property that was broken by the single-call bug:
|
||||
height/width frequencies were too low to distinguish nearby positions.
|
||||
"""
|
||||
from mlx_video.models.wan2.rope import rope_apply, rope_params
|
||||
from mlx_video.models.wan_2.rope import rope_apply, rope_params
|
||||
|
||||
d = 128
|
||||
freqs = mx.concatenate(
|
||||
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
|
||||
|
||||
def test_precomputed_matches_online(self):
|
||||
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
|
||||
from mlx_video.models.wan2.rope import (
|
||||
from mlx_video.models.wan_2.rope import (
|
||||
rope_apply,
|
||||
rope_params,
|
||||
rope_precompute_cos_sin,
|
||||
|
||||
Reference in New Issue
Block a user