format
This commit is contained in:
@@ -3,18 +3,17 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sinusoidal Embedding Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSinusoidalEmbedding:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.arange(10).astype(mx.float32)
|
||||
emb = sinusoidal_embedding_1d(256, pos)
|
||||
mx.eval(emb)
|
||||
@@ -23,6 +22,7 @@ class TestSinusoidalEmbedding:
|
||||
def test_position_zero(self):
|
||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0])
|
||||
emb = sinusoidal_embedding_1d(64, pos)
|
||||
mx.eval(emb)
|
||||
@@ -34,6 +34,7 @@ class TestSinusoidalEmbedding:
|
||||
|
||||
def test_different_positions_differ(self):
|
||||
from mlx_video.models.wan.model import sinusoidal_embedding_1d
|
||||
|
||||
pos = mx.array([0.0, 100.0, 999.0])
|
||||
emb = sinusoidal_embedding_1d(128, pos)
|
||||
mx.eval(emb)
|
||||
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
|
||||
# Head Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHead:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
B, L = 1, 24
|
||||
x = mx.random.normal((B, L, 64))
|
||||
@@ -60,6 +63,7 @@ class TestHead:
|
||||
|
||||
def test_modulation_shape(self):
|
||||
from mlx_video.models.wan.model import Head
|
||||
|
||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||
assert head.modulation.shape == (1, 2, 64)
|
||||
|
||||
@@ -68,12 +72,14 @@ class TestHead:
|
||||
# WanModel (Tiny) Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWanModel:
|
||||
def setup_method(self):
|
||||
mx.random.seed(42)
|
||||
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
|
||||
@@ -81,6 +87,7 @@ class TestWanModel:
|
||||
|
||||
def test_patchify_shape(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
# Input: [C=4, F=1, H=4, W=4]
|
||||
@@ -93,6 +100,7 @@ class TestWanModel:
|
||||
|
||||
def test_patchify_various_sizes(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
|
||||
@@ -108,6 +116,7 @@ class TestWanModel:
|
||||
def test_unpatchify_inverse(self):
|
||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 2, 4, 6
|
||||
@@ -123,6 +132,7 @@ class TestWanModel:
|
||||
|
||||
def test_forward_pass(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
@@ -140,6 +150,7 @@ class TestWanModel:
|
||||
|
||||
def test_forward_batch(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
@@ -148,7 +159,10 @@ class TestWanModel:
|
||||
|
||||
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
|
||||
t = mx.array([500.0, 200.0])
|
||||
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
|
||||
context = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((4, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model(x_list, t, context, seq_len)
|
||||
mx.eval(out[0], out[1])
|
||||
@@ -158,12 +172,17 @@ class TestWanModel:
|
||||
|
||||
def test_output_is_float32(self):
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
seq_len = (F // 1) * (H // 2) * (W // 2)
|
||||
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))], seq_len)
|
||||
out = model(
|
||||
[mx.random.normal((C, F, H, W))],
|
||||
mx.array([100.0]),
|
||||
[mx.random.normal((4, config.text_dim))],
|
||||
seq_len,
|
||||
)
|
||||
mx.eval(out[0])
|
||||
assert out[0].dtype == mx.float32
|
||||
|
||||
@@ -172,6 +191,7 @@ class TestWanModel:
|
||||
# Wan2.1 Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWan21Model:
|
||||
"""Test tiny Wan2.1-style model (single model mode)."""
|
||||
|
||||
@@ -181,6 +201,7 @@ class TestWan21Model:
|
||||
def _make_tiny_wan21_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 (single model)."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
# Override to tiny values
|
||||
config.dim = 64
|
||||
@@ -197,6 +218,7 @@ class TestWan21Model:
|
||||
def _make_tiny_wan21_1_3b_config(self):
|
||||
"""Create a tiny config mimicking Wan2.1 1.3B."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
# Override to tiny values (preserve 1.3B head structure: 12 heads)
|
||||
config.dim = 48
|
||||
@@ -271,7 +293,9 @@ class TestWan21Model:
|
||||
for i in range(3):
|
||||
t = sched.timesteps[i]
|
||||
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
|
||||
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
|
||||
pred_uncond = model(
|
||||
[latents], mx.array([t.item()]), [context_null], seq_len
|
||||
)[0]
|
||||
pred = pred_uncond + gs * (pred_cond - pred_uncond)
|
||||
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
@@ -304,6 +328,7 @@ class TestWan21Model:
|
||||
# Per-Token Timestep Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerTokenTimestep:
|
||||
"""Tests for per-token sinusoidal embedding."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user