format
This commit is contained in:
@@ -2,13 +2,11 @@
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
decode_with_tiling,
|
||||
split_in_spatial,
|
||||
split_in_temporal,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
|
||||
|
||||
# Causal: 1 + (4-1)*4 = 13
|
||||
out_causal = decode_with_tiling(
|
||||
dummy_decoder_causal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
|
||||
dummy_decoder_causal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=True,
|
||||
)
|
||||
mx.eval(out_causal)
|
||||
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
|
||||
|
||||
# Non-causal: 4*4 = 16
|
||||
out_noncausal = decode_with_tiling(
|
||||
dummy_decoder_noncausal, latents, config,
|
||||
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
|
||||
dummy_decoder_noncausal,
|
||||
latents,
|
||||
config,
|
||||
spatial_scale=scale,
|
||||
temporal_scale=scale,
|
||||
causal_temporal=False,
|
||||
)
|
||||
mx.eval(out_noncausal)
|
||||
assert out_noncausal.shape[2] == t * scale # 16
|
||||
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
# Both should produce the same shape
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
|
||||
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
|
||||
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
|
||||
mx.eval(out_tiled)
|
||||
|
||||
assert out_regular.shape == out_tiled.shape, (
|
||||
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
)
|
||||
assert (
|
||||
out_regular.shape == out_tiled.shape
|
||||
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
|
||||
|
||||
def test_decode_tiled_falls_through_when_small(self):
|
||||
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
|
||||
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
|
||||
mx.eval(out_tiled)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
np.array(out_regular), np.array(out_tiled),
|
||||
rtol=1e-4, atol=1e-4,
|
||||
np.array(out_regular),
|
||||
np.array(out_tiled),
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
err_msg="Tiled decode should match regular decode for small inputs",
|
||||
)
|
||||
|
||||
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
|
||||
from mlx_video.models.wan.vae import Decoder3d
|
||||
|
||||
# Small decoder for fast test
|
||||
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False])
|
||||
dec = Decoder3d(
|
||||
dim=16,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temporal_upsample=[True, True, False],
|
||||
)
|
||||
mx.eval(dec.parameters())
|
||||
|
||||
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3
|
||||
|
||||
Reference in New Issue
Block a user