format
This commit is contained in:
@@ -2,24 +2,25 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoPE:
|
||||
"""Tests for 3-way factorized RoPE."""
|
||||
|
||||
def test_rope_params_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(1024, 64)
|
||||
mx.eval(freqs)
|
||||
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
|
||||
|
||||
def test_rope_params_different_dims(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
for dim in [32, 64, 128]:
|
||||
freqs = rope_params(512, dim)
|
||||
mx.eval(freqs)
|
||||
@@ -27,6 +28,7 @@ class TestRoPE:
|
||||
|
||||
def test_rope_params_cos_sin_range(self):
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(256, 64)
|
||||
mx.eval(freqs)
|
||||
cos_vals = np.array(freqs[:, :, 0])
|
||||
@@ -37,13 +39,15 @@ class TestRoPE:
|
||||
def test_rope_params_position_zero(self):
|
||||
"""At position 0, cos should be 1 and sin should be 0."""
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
freqs = rope_params(10, 64)
|
||||
mx.eval(freqs)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
|
||||
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
|
||||
|
||||
def test_rope_apply_output_shape(self):
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
|
||||
x = mx.random.normal((B, L, N, D))
|
||||
freqs = rope_params(1024, D)
|
||||
@@ -54,7 +58,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_preserves_norm(self):
|
||||
"""RoPE rotation should preserve vector norms."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 3, 4
|
||||
L = F * H * W
|
||||
@@ -74,7 +79,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_with_padding(self):
|
||||
"""When seq_len < L, extra tokens should be preserved unchanged."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 1, 2, 16
|
||||
F, H, W = 2, 2, 2
|
||||
seq_len = F * H * W # 8
|
||||
@@ -94,7 +100,8 @@ class TestRoPE:
|
||||
|
||||
def test_rope_apply_batch(self):
|
||||
"""Test with batch_size > 1 and different grid sizes."""
|
||||
from mlx_video.models.wan.rope import rope_params, rope_apply
|
||||
from mlx_video.models.wan.rope import rope_apply, rope_params
|
||||
|
||||
B, N, D = 2, 2, 16
|
||||
grids = [(2, 3, 4), (2, 3, 4)]
|
||||
L = 2 * 3 * 4
|
||||
@@ -122,9 +129,11 @@ class TestRoPE:
|
||||
# Attention Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanRMSNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
@@ -134,6 +143,7 @@ class TestWanRMSNorm:
|
||||
def test_zero_mean_variance(self):
|
||||
"""RMS norm should make RMS ≈ 1 before scaling."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(64)
|
||||
x = mx.random.normal((1, 5, 64)) * 10.0
|
||||
out = norm(x)
|
||||
@@ -147,6 +157,7 @@ class TestWanRMSNorm:
|
||||
def test_dtype_preservation(self):
|
||||
"""RMSNorm weight is float32, so output is promoted to float32."""
|
||||
from mlx_video.models.wan.attention import WanRMSNorm
|
||||
|
||||
norm = WanRMSNorm(32)
|
||||
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
|
||||
out = norm(x)
|
||||
@@ -158,6 +169,7 @@ class TestWanRMSNorm:
|
||||
class TestWanLayerNorm:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64)
|
||||
x = mx.random.normal((2, 10, 64))
|
||||
out = norm(x)
|
||||
@@ -166,6 +178,7 @@ class TestWanLayerNorm:
|
||||
|
||||
def test_without_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(64, elementwise_affine=False)
|
||||
x = mx.random.normal((1, 4, 64))
|
||||
out = norm(x)
|
||||
@@ -178,6 +191,7 @@ class TestWanLayerNorm:
|
||||
|
||||
def test_with_affine(self):
|
||||
from mlx_video.models.wan.attention import WanLayerNorm
|
||||
|
||||
norm = WanLayerNorm(32, elementwise_affine=True)
|
||||
assert hasattr(norm, "weight")
|
||||
assert hasattr(norm, "bias")
|
||||
@@ -196,6 +210,7 @@ class TestWanSelfAttention:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -207,12 +222,14 @@ class TestWanSelfAttention:
|
||||
|
||||
def test_with_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
|
||||
assert attn.norm_q is not None
|
||||
assert attn.norm_k is not None
|
||||
|
||||
def test_without_qk_norm(self):
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
assert attn.norm_q is None
|
||||
assert attn.norm_k is None
|
||||
@@ -221,6 +238,7 @@ class TestWanSelfAttention:
|
||||
"""Test that masking works: shorter seq_lens should mask later tokens."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
|
||||
B, L = 1, 24
|
||||
F, H, W = 2, 3, 4
|
||||
@@ -245,6 +263,7 @@ class TestWanCrossAttention:
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 24, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
@@ -255,6 +274,7 @@ class TestWanCrossAttention:
|
||||
|
||||
def test_with_context_mask(self):
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
B, L_q, L_kv = 1, 12, 16
|
||||
x = mx.random.normal((B, L_q, self.dim))
|
||||
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
|
||||
# bfloat16 Autocast Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBFloat16Autocast:
|
||||
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
|
||||
for efficient matmul, matching official PyTorch autocast behavior."""
|
||||
@@ -292,6 +313,7 @@ class TestBFloat16Autocast:
|
||||
"""Self-attention should cast input to weight dtype for QKV projections."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -305,6 +327,7 @@ class TestBFloat16Autocast:
|
||||
def test_cross_attn_casts_to_weight_dtype(self):
|
||||
"""Cross-attention should cast input to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -318,6 +341,7 @@ class TestBFloat16Autocast:
|
||||
def test_cross_attn_kv_cache_uses_weight_dtype(self):
|
||||
"""prepare_kv should cast context to weight dtype."""
|
||||
from mlx_video.models.wan.attention import WanCrossAttention
|
||||
|
||||
attn = WanCrossAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -330,6 +354,7 @@ class TestBFloat16Autocast:
|
||||
def test_ffn_casts_to_weight_dtype(self):
|
||||
"""FFN should cast input to weight dtype for linear layers."""
|
||||
from mlx_video.models.wan.transformer import WanFFN
|
||||
|
||||
ffn = WanFFN(self.dim, 128)
|
||||
ffn.update(self._to_bf16(ffn.parameters()))
|
||||
|
||||
@@ -343,6 +368,7 @@ class TestBFloat16Autocast:
|
||||
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
|
||||
from mlx_video.models.wan.attention import WanSelfAttention
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
|
||||
attn = WanSelfAttention(self.dim, self.num_heads)
|
||||
attn.update(self._to_bf16(attn.parameters()))
|
||||
|
||||
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
|
||||
|
||||
def test_block_float32_residual_with_bf16_weights(self):
|
||||
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
from mlx_video.models.wan.rope import rope_params
|
||||
from mlx_video.models.wan.transformer import WanAttentionBlock
|
||||
|
||||
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
|
||||
block.update(self._to_bf16(block.parameters()))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user