Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan2.config import WanModelConfig
from mlx_video.models.wan2.wan2 import WanModel
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.wan_2 import WanModel
config = WanModelConfig()
config.dim = dim
@@ -51,7 +51,7 @@ class TestRoPEFrequencyConstruction:
def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128 # head_dim for all Wan models
# Reference: three separate calls
@@ -79,7 +79,7 @@ class TestRoPEFrequencyConstruction:
This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
"""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -120,7 +120,7 @@ class TestRoPEFrequencyConstruction:
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
"""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
d_h_dim = 2 * (d // 6) # 42
@@ -150,7 +150,7 @@ class TestRoPEFrequencyConstruction:
axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0.
"""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -182,7 +182,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
@@ -203,7 +203,7 @@ class TestRoPEFrequencyConstruction:
def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate(
@@ -242,7 +242,7 @@ class TestRoPEFrequencyMatchesReference:
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan2.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
@@ -298,7 +298,7 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan2.rope import rope_apply, rope_params
from mlx_video.models.wan_2.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate(
@@ -346,7 +346,7 @@ class TestRoPEApplyWithCorrectFreqs:
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan2.rope import (
from mlx_video.models.wan_2.rope import (
rope_apply,
rope_params,
rope_precompute_cos_sin,