format
This commit is contained in:
@@ -2,15 +2,13 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end tiny model forward pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end test with tiny model (no real weights needed)."""
|
||||
|
||||
@@ -78,6 +76,7 @@ class TestEndToEnd:
|
||||
# I2V Mask Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VMask:
|
||||
"""Tests for _build_i2v_mask."""
|
||||
|
||||
@@ -113,6 +112,7 @@ class TestI2VMaskAlignment:
|
||||
def test_mask_with_ti2v_dimensions(self):
|
||||
"""Mask should work with TI2V-5B typical dimensions."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
|
||||
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
|
||||
z_shape = (48, 21, 44, 80)
|
||||
@@ -133,6 +133,7 @@ class TestI2VMaskAlignment:
|
||||
def test_mask_per_token_timestep(self):
|
||||
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
|
||||
from mlx_video.generate_wan import _build_i2v_mask
|
||||
|
||||
z_shape = (4, 3, 4, 4)
|
||||
patch_size = (1, 2, 2)
|
||||
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
|
||||
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
|
||||
|
||||
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
|
||||
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
|
||||
np.testing.assert_allclose(
|
||||
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension Alignment Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDimensionAlignment:
|
||||
"""Tests for automatic dimension alignment in generate_wan."""
|
||||
|
||||
@@ -198,6 +202,7 @@ class TestDimensionAlignment:
|
||||
def test_patchify_valid_after_alignment(self):
|
||||
"""After alignment, patchify should succeed without reshape errors."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
config = _make_tiny_config()
|
||||
model = WanModel(config)
|
||||
|
||||
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
|
||||
patches, grid_size = model._patchify(vid)
|
||||
mx.eval(patches)
|
||||
assert patches.ndim == 3 # [1, L, dim]
|
||||
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
|
||||
assert grid_size == (
|
||||
t_latent,
|
||||
h_latent // patch_size[1],
|
||||
w_latent // patch_size[2],
|
||||
)
|
||||
|
||||
def test_alignment_with_ti2v_config(self):
|
||||
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config = WanModelConfig.wan22_ti2v_5b()
|
||||
align_h = config.patch_size[1] * config.vae_stride[1]
|
||||
align_w = config.patch_size[2] * config.vae_stride[2]
|
||||
|
||||
Reference in New Issue
Block a user