This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -4,7 +4,6 @@ import tempfile
from pathlib import Path
import mlx.core as mx
import numpy as np
import pytest
@@ -40,7 +39,9 @@ class TestLoRATypes:
lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
@@ -51,7 +52,9 @@ class TestLoRATypes:
class TestLoRALoader:
"""Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
def _make_lora_file(
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
):
"""Helper to create a mock LoRA safetensors file."""
weights = {}
for name in module_names:
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
"""Simulate typical Wan2.2 MLX model weight keys."""
keys = set()
for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
for layer in [
"self_attn.q",
"self_attn.k",
"self_attn.v",
"self_attn.o",
"cross_attn.q",
"cross_attn.k",
"cross_attn.v",
"cross_attn.o",
]:
keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight")
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
assert (
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
== "blocks.0.self_attn.q"
)
def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
result = _normalize_wan_lora_key(
"model.diffusion_model.blocks.0.self_attn.k", keys
)
assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self):
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
assert (
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
)
def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -219,7 +237,9 @@ class TestApplyLoRA:
# LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
@@ -230,7 +250,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
@@ -241,7 +263,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4))
@@ -255,12 +279,16 @@ class TestApplyLoRA:
w1 = LoRAWeights(
lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a",
rank=2,
alpha=2.0,
module_name="a",
)
w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b",
rank=2,
alpha=4.0,
module_name="b",
)
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
@@ -282,7 +310,9 @@ class TestApplyLoRA:
w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
rank=4,
alpha=4.0,
module_name="blocks.0.self_attn.q",
)
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras)
@@ -319,9 +349,7 @@ class TestEndToEnd:
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
}
result = load_and_apply_loras(
model_weights, [(str(lora_path), 1.0)]
)
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
# q weight should be modified, k unchanged
assert not mx.array_equal(