format
This commit is contained in:
@@ -1,9 +1,6 @@
|
||||
"""Tests for Wan2.2 I2V-14B support."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from wan_test_helpers import _make_tiny_config
|
||||
|
||||
|
||||
@@ -145,7 +142,10 @@ class TestModelYParameter:
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y = mx.random.normal((C_y, F, H, W))
|
||||
t = mx.array([500.0, 500.0])
|
||||
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
|
||||
ctx = [
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
mx.random.normal((6, config.text_dim)),
|
||||
]
|
||||
|
||||
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
||||
mx.eval(out[0], out[1])
|
||||
@@ -160,7 +160,9 @@ class TestVAEEncoder:
|
||||
def test_encoder3d_instantiation(self):
|
||||
from mlx_video.models.wan.vae import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
enc = Encoder3d(
|
||||
dim=32, z_dim=8
|
||||
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
||||
assert enc.conv1 is not None
|
||||
assert len(enc.downsamples) > 0
|
||||
assert len(enc.middle) == 3
|
||||
@@ -199,10 +201,10 @@ class TestVAEEncoder:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
||||
assert not hasattr(vae_no_enc, 'encoder')
|
||||
assert not hasattr(vae_no_enc, "encoder")
|
||||
|
||||
vae_enc = WanVAE(z_dim=4, encoder=True)
|
||||
assert hasattr(vae_enc, 'encoder')
|
||||
assert hasattr(vae_enc, "encoder")
|
||||
|
||||
|
||||
class TestResampleDownsample:
|
||||
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
|
||||
|
||||
# Build mask following reference logic
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
|
||||
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
||||
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
||||
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
|
||||
config = _make_tiny_i2v_config()
|
||||
config.vae_z_dim = 16
|
||||
config.out_dim = 16 # must match VAE z_dim for decode
|
||||
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
config.in_dim = (
|
||||
16 + 4 + 16
|
||||
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
model = WanModel(config)
|
||||
|
||||
# --- Tiny VAE (with encoder) ---
|
||||
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
|
||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||
video = mx.concatenate([
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
], axis=2)
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
|
||||
# --- VAE encode ---
|
||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
|
||||
|
||||
# --- Build I2V mask (4 channels) ---
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||
@@ -461,9 +476,9 @@ class TestDualModelSwitching:
|
||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||
if high_used_steps and low_used_steps:
|
||||
assert max(high_used_steps) < min(low_used_steps) or \
|
||||
min(high_used_steps) < max(low_used_steps), \
|
||||
"Model switching should happen during the loop"
|
||||
assert max(high_used_steps) < min(low_used_steps) or min(
|
||||
high_used_steps
|
||||
) < max(low_used_steps), "Model switching should happen during the loop"
|
||||
|
||||
assert latents.shape == (C_noise, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
||||
0
|
||||
)
|
||||
mx.eval(latents)
|
||||
|
||||
# Verify both guide scales were used
|
||||
|
||||
Reference in New Issue
Block a user