Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -10,7 +10,7 @@ import numpy as np
|
||||
|
||||
class TestWanFFN:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.transformer import WanFFN
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(64, 256)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
@@ -20,7 +20,7 @@ class TestWanFFN:
|
||||
|
||||
def test_gelu_activation(self):
|
||||
"""FFN should use GELU activation (non-linearity)."""
|
||||
from mlx_video.models.wan2.transformer import WanFFN
|
||||
from mlx_video.models.wan_2.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(32, 128)
|
||||
x = mx.ones((1, 1, 32)) * 2.0
|
||||
@@ -40,8 +40,8 @@ class TestWanAttentionBlock:
|
||||
self.num_heads = 4
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
@@ -68,13 +68,13 @@ class TestWanAttentionBlock:
|
||||
assert out.shape == (B, L, self.dim)
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
assert block.modulation.shape == (1, 6, self.dim)
|
||||
|
||||
def test_with_cross_attn_norm(self):
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
@@ -85,7 +85,7 @@ class TestWanAttentionBlock:
|
||||
assert block.norm3 is not None
|
||||
|
||||
def test_without_cross_attn_norm(self):
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(
|
||||
self.dim,
|
||||
@@ -97,8 +97,8 @@ class TestWanAttentionBlock:
|
||||
|
||||
def test_residual_connection(self):
|
||||
"""Output should differ from zero even with small random init."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, self.ffn_dim, self.num_heads)
|
||||
B, L = 1, 8
|
||||
@@ -129,15 +129,15 @@ class TestFloat32Modulation:
|
||||
|
||||
def test_block_modulation_in_float32(self):
|
||||
"""Modulation param starts random but should be usable as float32."""
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4, cross_attn_norm=True)
|
||||
assert block.modulation.dtype == mx.float32
|
||||
|
||||
def test_block_output_float32_with_bf16_modulation_input(self):
|
||||
"""Even if e (time embedding) arrives as bf16, modulation should cast to f32."""
|
||||
from mlx_video.models.wan2.rope import rope_params
|
||||
from mlx_video.models.wan2.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan_2.rope import rope_params
|
||||
from mlx_video.models.wan_2.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, 4)
|
||||
B, L = 1, 8
|
||||
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
|
||||
|
||||
def test_head_modulation_float32(self):
|
||||
"""Head modulation should be float32 even with bf16 e input."""
|
||||
from mlx_video.models.wan2.wan2 import Head
|
||||
from mlx_video.models.wan_2.wan_2 import Head
|
||||
|
||||
head = Head(self.dim, 4, (1, 2, 2))
|
||||
x = mx.random.normal((1, 8, self.dim))
|
||||
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
|
||||
|
||||
def test_model_time_embedding_float32(self):
|
||||
"""sinusoidal_embedding_1d output must be float32."""
|
||||
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([500.0])
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
|
||||
|
||||
def test_model_per_token_time_embedding_float32(self):
|
||||
"""Per-token time embeddings (I2V) should also be float32."""
|
||||
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
|
||||
|
||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||
emb = sinusoidal_embedding_1d(256, t)
|
||||
|
||||
Reference in New Issue
Block a user