feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""Tests for Wan weight conversion utilities."""
|
||||
|
||||
import logging
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -94,6 +96,27 @@ class TestSanitizeTransformerWeights:
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_transformer_weights
|
||||
weights = {
|
||||
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
|
||||
"patch_embedding.bias": mx.random.normal((5120,)),
|
||||
"text_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"text_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_embedding.0.weight": mx.zeros((64, 32)),
|
||||
"time_embedding.2.weight": mx.zeros((64, 64)),
|
||||
"time_projection.1.weight": mx.zeros((384, 64)),
|
||||
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.2.weight": mx.zeros((64, 128)),
|
||||
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
|
||||
"blocks.0.modulation": mx.zeros((1, 6, 64)),
|
||||
"head.head.weight": mx.zeros((64, 64)),
|
||||
"freqs": mx.zeros((1024, 64, 2)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_transformer_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeT5Weights:
|
||||
def test_gate_rename(self):
|
||||
@@ -119,6 +142,19 @@ class TestSanitizeT5Weights:
|
||||
for key in weights:
|
||||
assert key in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_t5_weights
|
||||
weights = {
|
||||
"token_embedding.weight": mx.zeros((100, 64)),
|
||||
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
|
||||
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
|
||||
"norm.weight": mx.zeros((64,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_t5_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
class TestSanitizeVAEWeights:
|
||||
def test_conv3d_transpose(self):
|
||||
@@ -161,6 +197,18 @@ class TestSanitizeVAEWeights:
|
||||
assert out["linear.weight"].shape == (8, 4)
|
||||
assert out["norm.weight"].shape == (8,)
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.convert_wan import sanitize_wan_vae_weights
|
||||
weights = {
|
||||
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
|
||||
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
|
||||
"decoder.norm.weight": mx.zeros((64,)),
|
||||
"decoder.bias": mx.zeros((16,)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
|
||||
sanitize_wan_vae_weights(weights)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wan2.1 Conversion Tests
|
||||
@@ -233,3 +281,27 @@ class TestSanitizeEncoderWeights:
|
||||
assert "encoder.conv1.weight" in out
|
||||
assert "conv1.weight" in out
|
||||
assert "conv2.weight" in out
|
||||
|
||||
def test_no_unconsumed_keys(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=True)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
|
||||
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
|
||||
}
|
||||
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
|
||||
sanitize_wan22_vae_weights(weights, include_encoder=False)
|
||||
assert "Unconsumed" not in caplog.text
|
||||
|
||||
@@ -291,3 +291,282 @@ class TestI2VMaskConstruction:
|
||||
encoded = mx.zeros((16, 5, 10, 18))
|
||||
y = mx.concatenate([mask, encoded], axis=0)
|
||||
assert y.shape == (20, 5, 10, 18)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: I2V end-to-end pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestI2VEndToEndPipeline:
|
||||
"""Full I2V pipeline: image → preprocess → VAE encode → y tensor → denoise → VAE decode."""
|
||||
|
||||
def test_full_i2v_pipeline(self):
|
||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
mx.random.seed(0)
|
||||
|
||||
# --- Tiny I2V model config (z_dim=16 to match VAE normalization stats) ---
|
||||
config = _make_tiny_i2v_config()
|
||||
config.vae_z_dim = 16
|
||||
config.out_dim = 16 # must match VAE z_dim for decode
|
||||
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
|
||||
model = WanModel(config)
|
||||
|
||||
# --- Tiny VAE (with encoder) ---
|
||||
vae = WanVAE(z_dim=config.vae_z_dim, encoder=True)
|
||||
|
||||
# --- Synthetic image: [B=1, 3, T=1, H=32, W=32] in [-1, 1] ---
|
||||
height, width = 32, 32
|
||||
num_frames = 5 # small temporal extent
|
||||
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
|
||||
video = mx.concatenate([
|
||||
img,
|
||||
mx.zeros((1, 3, num_frames - 1, height, width)),
|
||||
], axis=2)
|
||||
|
||||
# --- VAE encode ---
|
||||
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
assert z_video.ndim == 5
|
||||
assert z_video.shape[1] == config.vae_z_dim
|
||||
|
||||
z_video = z_video[0] # [z_dim, T_lat, H_lat, W_lat]
|
||||
t_latent = z_video.shape[1]
|
||||
h_latent = z_video.shape[2]
|
||||
w_latent = z_video.shape[3]
|
||||
|
||||
# --- Build I2V mask (4 channels) ---
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# --- Build y tensor: [mask(4ch) + encoded(z_dim ch)] ---
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
assert y_i2v.shape[0] == 4 + config.vae_z_dim
|
||||
|
||||
# --- Denoising loop (2 steps) ---
|
||||
C_noise = config.out_dim # noise channels
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (t_latent // pt) * (h_latent // ph) * (w_latent // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 2
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, t_latent, h_latent, w_latent))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
for i in range(num_steps):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents],
|
||||
mx.array([t_val]),
|
||||
[context],
|
||||
seq_len,
|
||||
y=[y_i2v],
|
||||
)[0]
|
||||
latents = sched.step(pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C_noise, t_latent, h_latent, w_latent)
|
||||
assert not mx.any(mx.isnan(latents)).item(), "NaN in denoised latents"
|
||||
assert not mx.any(mx.isinf(latents)).item(), "Inf in denoised latents"
|
||||
|
||||
# --- VAE decode ---
|
||||
decoded = vae.decode(latents[None]) # [1, 3, T_out, H_out, W_out]
|
||||
mx.eval(decoded)
|
||||
assert decoded.ndim == 5
|
||||
assert decoded.shape[0] == 1
|
||||
assert decoded.shape[1] == 3 # RGB output
|
||||
assert not mx.any(mx.isnan(decoded)).item(), "NaN in decoded video"
|
||||
assert not mx.any(mx.isinf(decoded)).item(), "Inf in decoded video"
|
||||
# VAE decode clips to [-1, 1]
|
||||
assert float(decoded.max()) <= 1.0
|
||||
assert float(decoded.min()) >= -1.0
|
||||
|
||||
|
||||
class TestDualModelSwitching:
|
||||
"""Test dual-model selection logic: high_noise vs low_noise based on boundary."""
|
||||
|
||||
def test_model_selection_by_timestep(self):
|
||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(1)
|
||||
config = _make_tiny_i2v_config()
|
||||
assert config.dual_model is True
|
||||
|
||||
high_noise_model = WanModel(config)
|
||||
low_noise_model = WanModel(config)
|
||||
|
||||
boundary = config.boundary * config.num_train_timesteps # 0.9 * 1000 = 900
|
||||
|
||||
C_noise = config.out_dim # 4
|
||||
C_y = config.in_dim - config.out_dim # 9 - 4 = 5
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
num_steps = 5
|
||||
sched.set_timesteps(num_steps, shift=config.sample_shift)
|
||||
|
||||
guide_scale = config.sample_guide_scale # (3.5, 3.5)
|
||||
assert isinstance(guide_scale, tuple) and len(guide_scale) == 2
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
high_used_steps = []
|
||||
low_used_steps = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(num_steps):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
gs = guide_scale[1]
|
||||
high_used_steps.append(i)
|
||||
else:
|
||||
model = low_noise_model
|
||||
gs = guide_scale[0]
|
||||
low_used_steps.append(i)
|
||||
|
||||
# CFG pass: cond + uncond
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
# With shift=5.0, early timesteps should be high (>=900), later ones low
|
||||
assert len(high_used_steps) > 0, "High-noise model was never selected"
|
||||
assert len(low_used_steps) > 0, "Low-noise model was never selected"
|
||||
# High-noise steps should come before low-noise steps (timesteps decrease)
|
||||
if high_used_steps and low_used_steps:
|
||||
assert max(high_used_steps) < min(low_used_steps) or \
|
||||
min(high_used_steps) < max(low_used_steps), \
|
||||
"Model switching should happen during the loop"
|
||||
|
||||
assert latents.shape == (C_noise, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
def test_guide_scale_tuple_applied_per_model(self):
|
||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(2)
|
||||
config = _make_tiny_i2v_config()
|
||||
config.sample_guide_scale = (2.0, 5.0) # distinct values
|
||||
|
||||
model = WanModel(config)
|
||||
boundary = config.boundary * config.num_train_timesteps
|
||||
|
||||
C_noise = config.out_dim
|
||||
F, H, W = 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(5, shift=config.sample_shift)
|
||||
|
||||
latents = mx.random.normal((C_noise, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
guide_scale = config.sample_guide_scale
|
||||
C_y = config.in_dim - config.out_dim # y channels
|
||||
y_i2v = mx.random.normal((C_y, F, H, W))
|
||||
|
||||
# Track which guide scale was used at each step
|
||||
gs_per_step = []
|
||||
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
for i in range(5):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
if timestep_val >= boundary:
|
||||
gs = guide_scale[1] # high_gs = 5.0
|
||||
else:
|
||||
gs = guide_scale[0] # low_gs = 2.0
|
||||
gs_per_step.append(gs)
|
||||
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([timestep_val, timestep_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
y=[y_i2v, y_i2v],
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
# Verify both guide scales were used
|
||||
assert 5.0 in gs_per_step, "High guide scale (5.0) was never used"
|
||||
assert 2.0 in gs_per_step, "Low guide scale (2.0) was never used"
|
||||
# High gs should appear first (high timesteps come first)
|
||||
first_high = gs_per_step.index(5.0)
|
||||
last_low = len(gs_per_step) - 1 - gs_per_step[::-1].index(2.0)
|
||||
assert first_high < last_low, "High gs steps should precede low gs steps"
|
||||
|
||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
|
||||
|
||||
mx.random.seed(3)
|
||||
config = _make_tiny_config()
|
||||
config.dual_model = False
|
||||
config.sample_guide_scale = (3.0, 5.0)
|
||||
|
||||
model = WanModel(config)
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
C, F, H, W = config.in_dim, 1, 4, 4
|
||||
pt, ph, pw = config.patch_size
|
||||
seq_len = (F // pt) * (H // ph) * (W // pw)
|
||||
|
||||
sched = FlowMatchEulerScheduler()
|
||||
sched.set_timesteps(3, shift=3.0)
|
||||
|
||||
latents = mx.random.normal((C, F, H, W))
|
||||
context = mx.random.normal((4, config.text_dim))
|
||||
|
||||
# Mimic generate_wan.py single-model logic:
|
||||
# gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
assert gs == 3.0, "Single model should use first element of guide_scale tuple"
|
||||
|
||||
for i in range(3):
|
||||
t_val = sched.timesteps[i].item()
|
||||
pred = model(
|
||||
[latents, latents],
|
||||
mx.array([t_val, t_val]),
|
||||
[context, context],
|
||||
seq_len,
|
||||
)
|
||||
noise_pred = pred[1] + gs * (pred[0] - pred[1])
|
||||
latents = sched.step(noise_pred[None], t_val, latents[None]).squeeze(0)
|
||||
mx.eval(latents)
|
||||
|
||||
assert latents.shape == (C, F, H, W)
|
||||
assert not mx.any(mx.isnan(latents)).item()
|
||||
|
||||
334
tests/test_wan_lora.py
Normal file
334
tests/test_wan_lora.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Tests for LoRA loading and application."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLoRATypes:
|
||||
"""Test LoRA data structures."""
|
||||
|
||||
def test_lora_weights_scale(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=32.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 2.0
|
||||
|
||||
def test_lora_weights_scale_default(self):
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.zeros((16, 64)),
|
||||
lora_B=mx.zeros((128, 16)),
|
||||
rank=16,
|
||||
alpha=16.0,
|
||||
module_name="test",
|
||||
)
|
||||
assert w.scale == 1.0
|
||||
|
||||
def test_applied_lora_delta(self):
|
||||
from mlx_video.lora.types import AppliedLoRA, LoRAWeights
|
||||
|
||||
lora_a = mx.ones((2, 4))
|
||||
lora_b = mx.ones((8, 2))
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||
delta = applied.compute_delta()
|
||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||
expected = 0.5 * mx.ones((8, 4)) * 2.0
|
||||
assert mx.allclose(delta, expected).item()
|
||||
|
||||
|
||||
class TestLoRALoader:
|
||||
"""Test LoRA weight loading from safetensors."""
|
||||
|
||||
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
|
||||
"""Helper to create a mock LoRA safetensors file."""
|
||||
weights = {}
|
||||
for name in module_names:
|
||||
if key_format == "AB":
|
||||
weights[f"{name}.lora_A.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_B.weight"] = mx.random.normal((out_dim, rank))
|
||||
else:
|
||||
weights[f"{name}.lora_down.weight"] = mx.random.normal((rank, in_dim))
|
||||
weights[f"{name}.lora_up.weight"] = mx.random.normal((out_dim, rank))
|
||||
path = Path(tmp_dir) / "test_lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
return path
|
||||
|
||||
def test_load_lora_a_b_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, ["blocks.0.self_attn.q"], key_format="AB")
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
w = lora_weights["blocks.0.self_attn.q"]
|
||||
assert w.rank == 4
|
||||
assert w.alpha == 4.0 # default: alpha == rank
|
||||
assert w.lora_A.shape == (4, 64)
|
||||
assert w.lora_B.shape == (128, 4)
|
||||
|
||||
def test_load_lora_down_up_format(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(
|
||||
tmp, ["blocks.0.self_attn.q"], key_format="down_up"
|
||||
)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert "blocks.0.self_attn.q" in lora_weights
|
||||
|
||||
def test_load_multiple_modules(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
modules = [
|
||||
"blocks.0.self_attn.q",
|
||||
"blocks.0.self_attn.k",
|
||||
"blocks.0.ffn.fc1",
|
||||
]
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = self._make_lora_file(tmp, modules)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert len(lora_weights) == 3
|
||||
for name in modules:
|
||||
assert name in lora_weights
|
||||
|
||||
def test_load_with_alpha(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
weights = {
|
||||
"test.lora_A.weight": mx.random.normal((8, 64)),
|
||||
"test.lora_B.weight": mx.random.normal((128, 8)),
|
||||
"test.alpha": mx.array(16.0),
|
||||
}
|
||||
path = Path(tmp) / "lora.safetensors"
|
||||
mx.save_safetensors(str(path), weights)
|
||||
lora_weights = load_lora_weights(path)
|
||||
assert lora_weights["test"].alpha == 16.0
|
||||
assert lora_weights["test"].rank == 8
|
||||
assert lora_weights["test"].scale == 2.0
|
||||
|
||||
def test_file_not_found(self):
|
||||
from mlx_video.lora.loader import load_lora_weights
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_lora_weights(Path("/nonexistent/lora.safetensors"))
|
||||
|
||||
|
||||
class TestWanKeyNormalization:
|
||||
"""Test Wan2.2 LoRA key normalization."""
|
||||
|
||||
def _wan_model_keys(self):
|
||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||
keys = set()
|
||||
for i in range(2):
|
||||
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
|
||||
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
|
||||
keys.add(f"blocks.{i}.{layer}.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||
keys.add("text_embedding_0.weight")
|
||||
keys.add("text_embedding_1.weight")
|
||||
keys.add("time_embedding_0.weight")
|
||||
keys.add("time_embedding_1.weight")
|
||||
keys.add("time_projection.weight")
|
||||
keys.add("patch_embedding_proj.weight")
|
||||
return keys
|
||||
|
||||
def test_direct_match(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
|
||||
|
||||
def test_strip_diffusion_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.0.self_attn.q", keys)
|
||||
assert result == "blocks.0.self_attn.q"
|
||||
|
||||
def test_strip_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
|
||||
assert result == "blocks.0.self_attn.k"
|
||||
|
||||
def test_ffn_key_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.0", keys) == "blocks.0.ffn.fc1"
|
||||
assert _normalize_wan_lora_key("blocks.0.ffn.2", keys) == "blocks.0.ffn.fc2"
|
||||
|
||||
def test_text_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("text_embedding.0", keys) == "text_embedding_0"
|
||||
assert _normalize_wan_lora_key("text_embedding.2", keys) == "text_embedding_1"
|
||||
|
||||
def test_time_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_embedding.0", keys) == "time_embedding_0"
|
||||
assert _normalize_wan_lora_key("time_embedding.2", keys) == "time_embedding_1"
|
||||
|
||||
def test_time_projection_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("time_projection.1", keys) == "time_projection"
|
||||
|
||||
def test_patch_embedding_mapping(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
|
||||
def test_combined_prefix_and_ffn(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("diffusion_model.blocks.1.ffn.0", keys)
|
||||
assert result == "blocks.1.ffn.fc1"
|
||||
|
||||
|
||||
class TestApplyLoRA:
|
||||
"""Test LoRA delta application to weights."""
|
||||
|
||||
def test_preserves_bfloat16_dtype(self):
|
||||
"""LoRA delta must not promote bfloat16 weights to float32."""
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.bfloat16)
|
||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||
|
||||
def test_preserves_float16_dtype(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4), dtype=mx.float16)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||
|
||||
def test_apply_single_lora(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.ones((8, 4))
|
||||
lora_a = mx.ones((2, 4)) * 0.1
|
||||
lora_b = mx.ones((8, 2)) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||
expected = original + 0.02 * mx.ones((8, 4))
|
||||
assert mx.allclose(result, expected, atol=1e-6).item()
|
||||
|
||||
def test_apply_multiple_loras(self):
|
||||
from mlx_video.lora.apply import apply_lora_to_linear
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
original = mx.zeros((8, 4))
|
||||
w1 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)),
|
||||
lora_B=mx.ones((8, 2)),
|
||||
rank=2, alpha=2.0, module_name="a",
|
||||
)
|
||||
w2 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)) * 2,
|
||||
lora_B=mx.ones((8, 2)) * 2,
|
||||
rank=2, alpha=4.0, module_name="b",
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||
# w2 delta: 2.0 * 0.5 * (2*ones(8,2) @ 2*ones(2,4)) = 1.0 * 8*ones(8,4) = 8
|
||||
delta1 = mx.ones((8, 4)) * 2.0
|
||||
delta2 = mx.ones((8, 4)) * 8.0
|
||||
expected = delta1 + delta2
|
||||
assert mx.allclose(result, expected, atol=1e-5).item()
|
||||
|
||||
def test_apply_loras_to_weights_dict(self):
|
||||
from mlx_video.lora.apply import apply_loras_to_weights
|
||||
from mlx_video.lora.types import LoRAWeights
|
||||
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
"blocks.0.ffn.fc1.weight": mx.ones((256, 64)),
|
||||
}
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.ones((4, 64)) * 0.01,
|
||||
lora_B=mx.ones((128, 4)) * 0.01,
|
||||
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
|
||||
)
|
||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||
# Only q should be modified
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
|
||||
|
||||
class TestEndToEnd:
|
||||
"""End-to-end LoRA loading and application."""
|
||||
|
||||
def test_load_and_apply_loras(self):
|
||||
from mlx_video.convert_wan import load_and_apply_loras
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# Create mock LoRA safetensors
|
||||
rank = 4
|
||||
weights = {
|
||||
"blocks.0.self_attn.q.lora_A.weight": mx.random.normal((rank, 64)),
|
||||
"blocks.0.self_attn.q.lora_B.weight": mx.random.normal((128, rank)),
|
||||
}
|
||||
lora_path = Path(tmp) / "test.safetensors"
|
||||
mx.save_safetensors(str(lora_path), weights)
|
||||
|
||||
# Create mock model weights
|
||||
model_weights = {
|
||||
"blocks.0.self_attn.q.weight": mx.ones((128, 64)),
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
}
|
||||
|
||||
result = load_and_apply_loras(
|
||||
model_weights, [(str(lora_path), 1.0)]
|
||||
)
|
||||
|
||||
# q weight should be modified, k unchanged
|
||||
assert not mx.array_equal(
|
||||
result["blocks.0.self_attn.q.weight"],
|
||||
model_weights["blocks.0.self_attn.q.weight"],
|
||||
).item()
|
||||
assert mx.array_equal(
|
||||
result["blocks.0.self_attn.k.weight"],
|
||||
model_weights["blocks.0.self_attn.k.weight"],
|
||||
).item()
|
||||
@@ -868,4 +868,84 @@ class TestVAEEncoderTemporalOrder:
|
||||
assert out_wrong.shape[1] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VAE Encode → Decode Round-Trip Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVAE21RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.1 VAE (channels-first)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
|
||||
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
|
||||
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
# No temporal up/downsampling to keep the test simple
|
||||
enc = Encoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_downsample=[False, False, False]
|
||||
)
|
||||
dec = Decoder3d(
|
||||
dim=dim, z_dim=z_dim, temporal_upsample=[False, False, False]
|
||||
)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, C=3, T=1, H=8, W=8]
|
||||
x = mx.random.normal((1, 3, 1, 8, 8)) * 0.5
|
||||
|
||||
z = enc(x)
|
||||
mx.eval(z)
|
||||
# 3 spatial downsamples (÷8): H=1, W=1
|
||||
assert z.shape == (1, z_dim, 1, 1, 1)
|
||||
|
||||
x_hat = dec(z)
|
||||
mx.eval(x_hat)
|
||||
# 3 spatial upsamples (×8): should recover original shape
|
||||
assert x_hat.shape == x.shape
|
||||
|
||||
out_np = np.array(x_hat)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert np.abs(out_np).max() < 1000
|
||||
|
||||
|
||||
class TestVAE22RoundTrip:
|
||||
"""Encode→decode round-trip for Wan 2.2 VAE (channels-last)."""
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
Wan22VAEDecoder,
|
||||
Wan22VAEEncoder,
|
||||
denormalize_latents,
|
||||
)
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
dec = Wan22VAEDecoder(z_dim=48, dec_dim=8)
|
||||
mx.eval(enc.parameters(), dec.parameters())
|
||||
|
||||
# [B=1, T=1, H=32, W=32, C=3]
|
||||
img = mx.random.normal((1, 1, 32, 32, 3)) * 0.5
|
||||
|
||||
z_norm = enc(img)
|
||||
mx.eval(z_norm)
|
||||
# patchify(÷2) + 3 spatial downsamples(÷8) = ÷16
|
||||
assert z_norm.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
z = denormalize_latents(z_norm)
|
||||
out = dec(z)
|
||||
mx.eval(out)
|
||||
|
||||
# 3 spatial upsamples(×8) + unpatchify(×2) = ×16
|
||||
assert out.shape[0] == 1 # batch
|
||||
assert out.shape[2] == 32 # H recovered
|
||||
assert out.shape[3] == 32 # W recovered
|
||||
assert out.shape[-1] == 3 # RGB
|
||||
|
||||
out_np = np.array(out)
|
||||
assert np.all(np.isfinite(out_np))
|
||||
assert out_np.min() >= -1.0 - 1e-6
|
||||
assert out_np.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user