feat(wan): Add LoRA with improved quantization pipeline

This commit is contained in:
Daniel
2026-02-28 14:11:13 +01:00
parent dbab95ec45
commit 849cc45d84
17 changed files with 1852 additions and 111 deletions

View File

@@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder:
assert out_wrong.shape[1] == 2
# ---------------------------------------------------------------------------
# VAE Encode → Decode Round-Trip Tests
# ---------------------------------------------------------------------------
class TestVAE21RoundTrip:
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
def test_encode_decode_shape_and_values(self):
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
z_dim = 4
dim = 8
# No temporal up/downsampling to keep the test simple
enc = Encoder3d(
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
)
dec = Decoder3d(
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, C=3, T=1, H=8, W=8]
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
z = enc(x)
mx.eval(z)
# 3 spatial downsamples (÷8): H=1, W=1
assert z.shape == (1, z_dim, 1, 1, 1)
x_hat = dec(z)
mx.eval(x_hat)
# 3 spatial upsamples (×8): should recover original shape
assert x_hat.shape == x.shape
out_np = np.array(x_hat)
assert np.all(np.isfinite(out_np))
assert np.abs(out_np).max() < 1000
class TestVAE22RoundTrip:
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
def test_encode_decode_shape_and_values(self):
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
from mlx_video.models.wan.vae22 import (
Wan22VAEDecoder,
Wan22VAEEncoder,
denormalize_latents,
)
enc = Wan22VAEEncoder(z_dim=48, dim=16)
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
mx.eval(enc.parameters(), dec.parameters())
# [B=1, T=1, H=32, W=32, C=3]
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
z_norm = enc(img)
mx.eval(z_norm)
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
assert z_norm.shape == (1, 1, 2, 2, 48)
z = denormalize_latents(z_norm)
out = dec(z)
mx.eval(out)
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
assert out.shape[0] == 1 # batch
assert out.shape[2] == 32 # H recovered
assert out.shape[3] == 32 # W recovered
assert out.shape[-1] == 3 # RGB
out_np = np.array(out)
assert np.all(np.isfinite(out_np))
assert out_np.min() >= -1.0 - 1e-6
assert out_np.max() <= 1.0 + 1e-6