This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -2,15 +2,13 @@
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------
class TestEndToEnd:
"""End-to-end test with tiny model (no real weights needed)."""
@@ -78,6 +76,7 @@ class TestEndToEnd:
# I2V Mask Tests
# ---------------------------------------------------------------------------
class TestI2VMask:
"""Tests for _build_i2v_mask."""
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
z_shape = (48, 21, 44, 80)
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
np.testing.assert_allclose(
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
)
# ---------------------------------------------------------------------------
# Dimension Alignment Tests
# ---------------------------------------------------------------------------
class TestDimensionAlignment:
"""Tests for automatic dimension alignment in generate_wan."""
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
patches, grid_size = model._patchify(vid)
mx.eval(patches)
assert patches.ndim == 3 # [1, L, dim]
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
assert grid_size == (
t_latent,
h_latent // patch_size[1],
w_latent // patch_size[2],
)
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]
align_w = config.patch_size[2] * config.vae_stride[2]