588 lines
21 KiB
Python
588 lines
21 KiB
Python
"""Tests for Wan2.2 I2V-14B support."""
|
|
|
|
import mlx.core as mx
|
|
from wan_test_helpers import _make_tiny_config
|
|
|
|
|
|
def _make_tiny_i2v_config():
|
|
"""Create a tiny I2V-14B config for testing."""
|
|
config = _make_tiny_config()
|
|
config.model_type = "i2v"
|
|
config.in_dim = 9 # 4 noise + 4 image + 1 mask (scaled down from 16+16+4=36)
|
|
config.out_dim = 4
|
|
config.vae_z_dim = 4
|
|
config.vae_stride = (4, 8, 8)
|
|
config.dual_model = True
|
|
config.boundary = 0.900
|
|
config.sample_shift = 5.0
|
|
config.sample_guide_scale = (3.5, 3.5)
|
|
return config
|
|
|
|
|
|
class TestI2VConfig:
|
|
"""Test I2V-14B config preset."""
|
|
|
|
def test_wan22_i2v_14b_preset(self):
|
|
from mlx_video.models.wan.config import WanModelConfig
|
|
|
|
config = WanModelConfig.wan22_i2v_14b()
|
|
assert config.model_type == "i2v"
|
|
assert config.in_dim == 36
|
|
assert config.out_dim == 16
|
|
assert config.dim == 5120
|
|
assert config.num_layers == 40
|
|
assert config.dual_model is True
|
|
assert config.boundary == 0.900
|
|
assert config.sample_shift == 5.0
|
|
assert config.sample_guide_scale == (3.5, 3.5)
|
|
assert config.vae_stride == (4, 8, 8)
|
|
assert config.vae_z_dim == 16
|
|
|
|
def test_i2v_vs_t2v_differences(self):
|
|
from mlx_video.models.wan.config import WanModelConfig
|
|
|
|
i2v = WanModelConfig.wan22_i2v_14b()
|
|
t2v = WanModelConfig.wan22_t2v_14b()
|
|
|
|
assert i2v.model_type == "i2v"
|
|
assert t2v.model_type == "t2v"
|
|
assert i2v.in_dim == 36 and t2v.in_dim == 16
|
|
assert i2v.boundary == 0.900 and t2v.boundary == 0.875
|
|
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
|
|
|
|
def test_i2v_serialization_roundtrip(self):
|
|
from mlx_video.models.wan.config import WanModelConfig
|
|
|
|
config = WanModelConfig.wan22_i2v_14b()
|
|
d = config.to_dict()
|
|
restored = WanModelConfig.from_dict(d)
|
|
assert restored.model_type == "i2v"
|
|
assert restored.in_dim == 36
|
|
assert restored.boundary == 0.900
|
|
|
|
|
|
class TestModelYParameter:
|
|
"""Test y parameter channel concatenation in WanModel."""
|
|
|
|
def test_forward_without_y(self):
|
|
"""Standard T2V forward pass (no y) still works."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
|
|
config = _make_tiny_config()
|
|
model = WanModel(config)
|
|
|
|
C, F, H, W = config.in_dim, 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
x_list = [mx.random.normal((C, F, H, W))]
|
|
t = mx.array([500.0])
|
|
context = [mx.random.normal((6, config.text_dim))]
|
|
|
|
out = model(x_list, t, context, seq_len)
|
|
mx.eval(out[0])
|
|
assert out[0].shape == (C, F, H, W)
|
|
|
|
def test_forward_with_y(self):
|
|
"""I2V forward pass with y channel concatenation."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
|
|
config = _make_tiny_i2v_config()
|
|
model = WanModel(config)
|
|
|
|
C_noise = 4 # noise channels
|
|
C_y = 5 # mask (1) + image (4)
|
|
F, H, W = 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
x_list = [mx.random.normal((C_noise, F, H, W))]
|
|
y_list = [mx.random.normal((C_y, F, H, W))]
|
|
t = mx.array([500.0])
|
|
context = [mx.random.normal((6, config.text_dim))]
|
|
|
|
out = model(x_list, t, context, seq_len, y=y_list)
|
|
mx.eval(out[0])
|
|
# Output should match noise channels (out_dim), not concatenated in_dim
|
|
assert out[0].shape == (config.out_dim, F, H, W)
|
|
|
|
def test_y_none_is_noop(self):
|
|
"""Passing y=None should be identical to not passing y."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
|
|
config = _make_tiny_config()
|
|
model = WanModel(config)
|
|
|
|
C, F, H, W = config.in_dim, 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
mx.random.seed(42)
|
|
x = mx.random.normal((C, F, H, W))
|
|
t = mx.array([500.0])
|
|
ctx = [mx.random.normal((6, config.text_dim))]
|
|
|
|
out1 = model([x], t, ctx, seq_len)[0]
|
|
out2 = model([x], t, ctx, seq_len, y=None)[0]
|
|
mx.eval(out1, out2)
|
|
assert mx.allclose(out1, out2, atol=1e-5).item()
|
|
|
|
def test_batched_cfg_with_y(self):
|
|
"""Batched CFG (B=2) with y should work."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
|
|
config = _make_tiny_i2v_config()
|
|
model = WanModel(config)
|
|
|
|
C_noise, C_y = 4, 5
|
|
F, H, W = 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
latents = mx.random.normal((C_noise, F, H, W))
|
|
y = mx.random.normal((C_y, F, H, W))
|
|
t = mx.array([500.0, 500.0])
|
|
ctx = [
|
|
mx.random.normal((6, config.text_dim)),
|
|
mx.random.normal((6, config.text_dim)),
|
|
]
|
|
|
|
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
|
|
mx.eval(out[0], out[1])
|
|
assert len(out) == 2
|
|
assert out[0].shape == (config.out_dim, F, H, W)
|
|
assert out[1].shape == (config.out_dim, F, H, W)
|
|
|
|
|
|
class TestVAEEncoder:
|
|
"""Test Wan2.1 VAE encoder."""
|
|
|
|
def test_encoder3d_instantiation(self):
|
|
from mlx_video.models.wan.vae import Encoder3d
|
|
|
|
enc = Encoder3d(
|
|
dim=32, z_dim=8
|
|
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
|
|
assert enc.conv1 is not None
|
|
assert len(enc.downsamples) > 0
|
|
assert len(enc.middle) == 3
|
|
|
|
def test_encoder3d_output_shape(self):
|
|
"""Encoder should downsample spatially by 8x and temporally by 4x."""
|
|
from mlx_video.models.wan.vae import Encoder3d
|
|
|
|
enc = Encoder3d(dim=32, z_dim=8)
|
|
# Random input: [B=1, 3, T=5, H=32, W=32]
|
|
x = mx.random.normal((1, 3, 5, 32, 32))
|
|
out = enc(x)
|
|
mx.eval(out)
|
|
# With default dim_mult=[1,2,4,4] and temporal_downsample=[True,True,False]:
|
|
# Spatial: 32 -> 16 -> 8 -> 4 (3 spatial downsamples)
|
|
# Temporal: 5 -> 3 -> 2 (2 temporal downsamples: downsample3d stride 2)
|
|
assert out.shape[0] == 1
|
|
assert out.shape[1] == 8 # z_dim
|
|
assert out.shape[3] == 32 // 8 # spatial /8
|
|
assert out.shape[4] == 32 // 8
|
|
|
|
def test_wan_vae_encode(self):
|
|
"""WanVAE with encoder=True should produce normalized latents."""
|
|
from mlx_video.models.wan.vae import WanVAE
|
|
|
|
vae = WanVAE(z_dim=16, encoder=True)
|
|
# Input: [B=1, 3, T=5, H=32, W=32]
|
|
x = mx.random.normal((1, 3, 5, 32, 32))
|
|
z = vae.encode(x)
|
|
mx.eval(z)
|
|
assert z.shape[0] == 1
|
|
assert z.shape[1] == 16 # z_dim
|
|
|
|
def test_wan_vae_encoder_flag(self):
|
|
"""WanVAE without encoder flag should not have encoder attribute."""
|
|
from mlx_video.models.wan.vae import WanVAE
|
|
|
|
vae_no_enc = WanVAE(z_dim=4, encoder=False)
|
|
assert not hasattr(vae_no_enc, "encoder")
|
|
|
|
vae_enc = WanVAE(z_dim=4, encoder=True)
|
|
assert hasattr(vae_enc, "encoder")
|
|
|
|
|
|
class TestResampleDownsample:
|
|
"""Test downsample modes in Resample."""
|
|
|
|
def test_downsample2d(self):
|
|
from mlx_video.models.wan.vae import Resample
|
|
|
|
r = Resample(dim=16, mode="downsample2d")
|
|
x = mx.random.normal((1, 16, 2, 8, 8))
|
|
out = r(x)
|
|
mx.eval(out)
|
|
# Spatial /2, temporal unchanged, channels same
|
|
assert out.shape == (1, 16, 2, 4, 4)
|
|
|
|
def test_downsample3d(self):
|
|
from mlx_video.models.wan.vae import Resample
|
|
|
|
r = Resample(dim=16, mode="downsample3d")
|
|
x = mx.random.normal((1, 16, 4, 8, 8))
|
|
out = r(x)
|
|
mx.eval(out)
|
|
# Spatial /2, temporal /2, channels same
|
|
assert out.shape == (1, 16, 2, 4, 4)
|
|
|
|
def test_upsample2d_still_works(self):
|
|
from mlx_video.models.wan.vae import Resample
|
|
|
|
r = Resample(dim=16, mode="upsample2d")
|
|
x = mx.random.normal((1, 16, 2, 4, 4))
|
|
out = r(x)
|
|
mx.eval(out)
|
|
assert out.shape == (1, 8, 2, 8, 8)
|
|
|
|
def test_upsample3d_still_works(self):
|
|
from mlx_video.models.wan.vae import Resample
|
|
|
|
r = Resample(dim=16, mode="upsample3d")
|
|
x = mx.random.normal((1, 16, 2, 4, 4))
|
|
out = r(x)
|
|
mx.eval(out)
|
|
assert out.shape == (1, 8, 4, 8, 8)
|
|
|
|
|
|
class TestI2VMaskConstruction:
|
|
"""Test mask construction for I2V-14B."""
|
|
|
|
def test_mask_shape(self):
|
|
"""I2V-14B mask should have 4 channels with correct temporal structure."""
|
|
num_frames = 81
|
|
h_latent, w_latent = 10, 18 # example latent dims
|
|
t_latent = (num_frames - 1) // 4 + 1 # = 21
|
|
|
|
# Build mask following reference logic
|
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
|
msk = mx.concatenate(
|
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
|
)
|
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
|
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
|
|
|
assert msk.shape == (4, t_latent, h_latent, w_latent)
|
|
|
|
def test_mask_values(self):
|
|
"""First temporal position should be 1, rest 0."""
|
|
num_frames = 9
|
|
h_latent, w_latent = 4, 4
|
|
t_latent = (num_frames - 1) // 4 + 1 # = 3
|
|
|
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
|
msk = mx.concatenate(
|
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
|
)
|
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
|
msk = msk.transpose(0, 2, 1, 3, 4)[0]
|
|
|
|
mx.eval(msk)
|
|
# First temporal position: all 4 channels should be 1
|
|
assert mx.all(msk[:, 0] == 1.0).item()
|
|
# Rest: all should be 0
|
|
assert mx.all(msk[:, 1:] == 0.0).item()
|
|
|
|
def test_y_tensor_shape(self):
|
|
"""y = concat([mask_4ch, encoded_video_16ch]) should be 20 channels."""
|
|
mask = mx.zeros((4, 5, 10, 18))
|
|
encoded = mx.zeros((16, 5, 10, 18))
|
|
y = mx.concatenate([mask, encoded], axis=0)
|
|
assert y.shape == (20, 5, 10, 18)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: I2V end-to-end pipeline
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestI2VEndToEndPipeline:
|
|
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
|
|
|
|
def test_full_i2v_pipeline(self):
|
|
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
|
from mlx_video.models.wan.vae import WanVAE
|
|
|
|
mx.random.seed(0)
|
|
|
|
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
|
|
config = _make_tiny_i2v_config()
|
|
config.vae_z_dim = 16
|
|
config.out_dim = 16 # must match VAE z_dim for decode
|
|
config.in_dim = (
|
|
16 + 4 + 16
|
|
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
|
model = WanModel(config)
|
|
|
|
# --- Tiny VAE (with encoder) ---
|
|
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
|
|
|
|
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
|
|
height, width = 32, 32
|
|
num_frames = 5 # small temporal extent
|
|
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
|
|
|
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
|
video = mx.concatenate(
|
|
[
|
|
img,
|
|
mx.zeros((1, 3, num_frames - 1, height, width)),
|
|
],
|
|
axis=2,
|
|
)
|
|
|
|
# --- VAE encode ---
|
|
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
|
mx.eval(z_video)
|
|
assert z_video.ndim == 5
|
|
assert z_video.shape[1] == config.vae_z_dim
|
|
|
|
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
|
|
t_latent = z_video.shape[1]
|
|
h_latent = z_video.shape[2]
|
|
w_latent = z_video.shape[3]
|
|
|
|
# --- Build I2V mask (4 channels) ---
|
|
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
|
msk = mx.concatenate(
|
|
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
|
)
|
|
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
|
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
|
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
|
|
|
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
|
|
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
|
mx.eval(y_i2v)
|
|
assert y_i2v.shape[0] == 4 + config.vae_z_dim
|
|
|
|
# --- Denoising loop (2 steps) ---
|
|
C_noise = config.out_dim # noise channels
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
|
|
|
|
sched = FlowMatchEulerScheduler()
|
|
num_steps = 2
|
|
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
|
|
|
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
|
|
context = mx.random.normal((4, config.text_dim))
|
|
|
|
for i in range(num_steps):
|
|
t_val = sched.timesteps[i].item()
|
|
pred = model(
|
|
[latents],
|
|
mx.array([t_val]),
|
|
[context],
|
|
seq_len,
|
|
y=[y_i2v],
|
|
)[0]
|
|
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
|
|
mx.eval(latents)
|
|
|
|
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
|
|
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
|
|
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
|
|
|
|
# --- VAE decode ---
|
|
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
|
|
mx.eval(decoded)
|
|
assert decoded.ndim == 5
|
|
assert decoded.shape[0] == 1
|
|
assert decoded.shape[1] == 3 # RGB output
|
|
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
|
|
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
|
|
# VAE decode clips to [-1, 1]
|
|
assert float(decoded.max()) <= 1.0
|
|
assert float(decoded.min()) >= -1.0
|
|
|
|
|
|
class TestDualModelSwitching:
|
|
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
|
|
|
|
def test_model_selection_by_timestep(self):
|
|
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
|
|
|
mx.random.seed(1)
|
|
config = _make_tiny_i2v_config()
|
|
assert config.dual_model is True
|
|
|
|
high_noise_model = WanModel(config)
|
|
low_noise_model = WanModel(config)
|
|
|
|
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
|
|
|
|
C_noise = config.out_dim # 4
|
|
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
|
|
F, H, W = 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
sched = FlowMatchEulerScheduler()
|
|
num_steps = 5
|
|
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
|
|
|
guide_scale = config.sample_guide_scale # (3.5, 3.5)
|
|
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
|
|
|
|
latents = mx.random.normal((C_noise, F, H, W))
|
|
y_i2v = mx.random.normal((C_y, F, H, W))
|
|
context = mx.random.normal((4, config.text_dim))
|
|
|
|
high_used_steps = []
|
|
low_used_steps = []
|
|
|
|
timestep_list = sched.timesteps.tolist()
|
|
for i in range(num_steps):
|
|
timestep_val = timestep_list[i]
|
|
|
|
if timestep_val >= boundary:
|
|
model = high_noise_model
|
|
gs = guide_scale[1]
|
|
high_used_steps.append(i)
|
|
else:
|
|
model = low_noise_model
|
|
gs = guide_scale[0]
|
|
low_used_steps.append(i)
|
|
|
|
# CFG pass: cond + uncond
|
|
preds = model(
|
|
[latents, latents],
|
|
mx.array([timestep_val, timestep_val]),
|
|
[context, context],
|
|
seq_len,
|
|
y=[y_i2v, y_i2v],
|
|
)
|
|
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
|
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
|
|
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
|
0
|
|
)
|
|
mx.eval(latents)
|
|
|
|
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
|
assert len(high_used_steps) > 0, "High-noise model was never selected"
|
|
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
|
# High-noise steps should come before low-noise steps (timesteps decrease)
|
|
if high_used_steps and low_used_steps:
|
|
assert max(high_used_steps) < min(low_used_steps) or min(
|
|
high_used_steps
|
|
) < max(low_used_steps), "Model switching should happen during the loop"
|
|
|
|
assert latents.shape == (C_noise, F, H, W)
|
|
assert not mx.any(mx.isnan(latents)).item()
|
|
|
|
def test_guide_scale_tuple_applied_per_model(self):
|
|
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
|
|
|
mx.random.seed(2)
|
|
config = _make_tiny_i2v_config()
|
|
config.sample_guide_scale = (2.0, 5.0) # distinct values
|
|
|
|
model = WanModel(config)
|
|
boundary = config.boundary * config.num_train_timesteps
|
|
|
|
C_noise = config.out_dim
|
|
F, H, W = 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
sched = FlowMatchEulerScheduler()
|
|
sched.set_timesteps(5, shift=config.sample_shift)
|
|
|
|
latents = mx.random.normal((C_noise, F, H, W))
|
|
context = mx.random.normal((4, config.text_dim))
|
|
guide_scale = config.sample_guide_scale
|
|
C_y = config.in_dim - config.out_dim # y channels
|
|
y_i2v = mx.random.normal((C_y, F, H, W))
|
|
|
|
# Track which guide scale was used at each step
|
|
gs_per_step = []
|
|
|
|
timestep_list = sched.timesteps.tolist()
|
|
for i in range(5):
|
|
timestep_val = timestep_list[i]
|
|
|
|
if timestep_val >= boundary:
|
|
gs = guide_scale[1] # high_gs = 5.0
|
|
else:
|
|
gs = guide_scale[0] # low_gs = 2.0
|
|
gs_per_step.append(gs)
|
|
|
|
pred = model(
|
|
[latents, latents],
|
|
mx.array([timestep_val, timestep_val]),
|
|
[context, context],
|
|
seq_len,
|
|
y=[y_i2v, y_i2v],
|
|
)
|
|
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
|
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
|
|
0
|
|
)
|
|
mx.eval(latents)
|
|
|
|
# Verify both guide scales were used
|
|
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
|
|
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
|
|
# High gs should appear first (high timesteps come first)
|
|
first_high = gs_per_step.index(5.0)
|
|
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
|
|
assert first_high < last_low, "High gs steps should precede low gs steps"
|
|
|
|
def test_single_model_fallback_with_tuple_guide_scale(self):
|
|
"""When dual_model=False, guide_scale tuple should use first element."""
|
|
from mlx_video.models.wan.model import WanModel
|
|
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
|
|
|
mx.random.seed(3)
|
|
config = _make_tiny_config()
|
|
config.dual_model = False
|
|
config.sample_guide_scale = (3.0, 5.0)
|
|
|
|
model = WanModel(config)
|
|
guide_scale = config.sample_guide_scale
|
|
|
|
C, F, H, W = config.in_dim, 1, 4, 4
|
|
pt, ph, pw = config.patch_size
|
|
seq_len = (F // pt) * (H // ph) * (W // pw)
|
|
|
|
sched = FlowMatchEulerScheduler()
|
|
sched.set_timesteps(3, shift=3.0)
|
|
|
|
latents = mx.random.normal((C, F, H, W))
|
|
context = mx.random.normal((4, config.text_dim))
|
|
|
|
# Mimic generate_wan.py single-model logic:
|
|
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
|
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
|
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
|
|
|
|
for i in range(3):
|
|
t_val = sched.timesteps[i].item()
|
|
pred = model(
|
|
[latents, latents],
|
|
mx.array([t_val, t_val]),
|
|
[context, context],
|
|
seq_len,
|
|
)
|
|
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
|
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
|
|
mx.eval(latents)
|
|
|
|
assert latents.shape == (C, F, H, W)
|
|
assert not mx.any(mx.isnan(latents)).item()
|