This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -2,13 +2,11 @@
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
split_in_temporal,
)
@@ -49,16 +47,24 @@ class TestNonCausalTemporal:
# Causal: 1 + (4-1)*4 = 13
out_causal = decode_with_tiling(
dummy_decoder_causal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=True,
dummy_decoder_causal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=True,
)
mx.eval(out_causal)
assert out_causal.shape[2] == 1 + (t - 1) * scale # 13
# Non-causal: 4*4 = 16
out_noncausal = decode_with_tiling(
dummy_decoder_noncausal, latents, config,
spatial_scale=scale, temporal_scale=scale, causal_temporal=False,
dummy_decoder_noncausal,
latents,
config,
spatial_scale=scale,
temporal_scale=scale,
causal_temporal=False,
)
mx.eval(out_noncausal)
assert out_noncausal.shape[2] == t * scale # 16
@@ -100,9 +106,9 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled)
# Both should produce the same shape
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as __call__."""
@@ -120,8 +126,10 @@ class TestWan22TiledDecoding:
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
@@ -152,9 +160,9 @@ class TestWan21TiledDecoding:
out_tiled = vae.decode_tiled(z, tiling_config=TilingConfig.default())
mx.eval(out_tiled)
assert out_regular.shape == out_tiled.shape, (
f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
)
assert (
out_regular.shape == out_tiled.shape
), f"Shape mismatch: regular={out_regular.shape} vs tiled={out_tiled.shape}"
def test_decode_tiled_falls_through_when_small(self):
"""When input is smaller than tile size, decode_tiled should produce same output as decode."""
@@ -171,8 +179,10 @@ class TestWan21TiledDecoding:
mx.eval(out_tiled)
np.testing.assert_allclose(
np.array(out_regular), np.array(out_tiled),
rtol=1e-4, atol=1e-4,
np.array(out_regular),
np.array(out_tiled),
rtol=1e-4,
atol=1e-4,
err_msg="Tiled decode should match regular decode for small inputs",
)
@@ -185,8 +195,13 @@ class TestWan21TemporalScale:
from mlx_video.models.wan.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(dim=16, z_dim=4, dim_mult=[1, 1, 1, 1], num_res_blocks=1,
temporal_upsample=[True, True, False])
dec = Decoder3d(
dim=16,
z_dim=4,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temporal_upsample=[True, True, False],
)
mx.eval(dec.parameters())
x = mx.random.normal((1, 4, 3, 4, 4)) # T=3