feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder:
|
||||
assert out_wrong.shape[1] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE Encode → Decode Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVAE21RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
|
||||
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
|
||||
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
# No temporal up/downsampling to keep the test simple
|
||||
enc = Encoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
||||
)
|
||||
dec = Decoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
||||
)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, C=3, T=1, H=8, W=8]
|
||||
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
|
||||
|
||||
z = enc(x)
|
||||
mx.eval(z)
|
||||
# 3 spatial downsamples (÷8): H=1, W=1
|
||||
assert z.shape == (1, z_dim, 1, 1, 1)
|
||||
|
||||
x_hat = dec(z)
|
||||
mx.eval(x_hat)
|
||||
# 3 spatial upsamples (×8): should recover original shape
|
||||
assert x_hat.shape == x.shape
|
||||
|
||||
out_np = np.array(x_hat)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert np.abs(out_np).max() < 1000
|
||||
|
||||
|
||||
class TestVAE22RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
Wan22VAEDecoder,
|
||||
Wan22VAEEncoder,
|
||||
denormalize_latents,
|
||||
)
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, T=1, H=32, W=32, C=3]
|
||||
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
|
||||
|
||||
z_norm = enc(img)
|
||||
mx.eval(z_norm)
|
||||
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
|
||||
assert z_norm.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
z = denormalize_latents(z_norm)
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
|
||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
|
||||
out_np = np.array(out)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert out_np.min() >= -1.0 - 1e-6
|
||||
assert out_np.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user