format
This commit is contained in:
@@ -4,7 +4,6 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -40,7 +39,9 @@ class TestLoRATypes:
|
||||
|
||||
lora_a = mx.ones((2, 4))
|
||||
lora_b = mx.ones((8, 2))
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
applied = AppliedLoRA(weights=w, strength=0.5)
|
||||
delta = applied.compute_delta()
|
||||
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
|
||||
@@ -51,7 +52,9 @@ class TestLoRATypes:
|
||||
class TestLoRALoader:
|
||||
"""Test LoRA weight loading from safetensors."""
|
||||
|
||||
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
|
||||
def _make_lora_file(
|
||||
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
|
||||
):
|
||||
"""Helper to create a mock LoRA safetensors file."""
|
||||
weights = {}
|
||||
for name in module_names:
|
||||
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
|
||||
"""Simulate typical Wan2.2 MLX model weight keys."""
|
||||
keys = set()
|
||||
for i in range(2):
|
||||
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
|
||||
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
|
||||
for layer in [
|
||||
"self_attn.q",
|
||||
"self_attn.k",
|
||||
"self_attn.v",
|
||||
"self_attn.o",
|
||||
"cross_attn.q",
|
||||
"cross_attn.k",
|
||||
"cross_attn.v",
|
||||
"cross_attn.o",
|
||||
]:
|
||||
keys.add(f"blocks.{i}.{layer}.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc1.weight")
|
||||
keys.add(f"blocks.{i}.ffn.fc2.weight")
|
||||
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
|
||||
assert (
|
||||
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
|
||||
== "blocks.0.self_attn.q"
|
||||
)
|
||||
|
||||
def test_strip_diffusion_model_prefix(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
|
||||
result = _normalize_wan_lora_key(
|
||||
"model.diffusion_model.blocks.0.self_attn.k", keys
|
||||
)
|
||||
assert result == "blocks.0.self_attn.k"
|
||||
|
||||
def test_ffn_key_mapping(self):
|
||||
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
|
||||
keys = self._wan_model_keys()
|
||||
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
assert (
|
||||
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
|
||||
)
|
||||
|
||||
def test_combined_prefix_and_ffn(self):
|
||||
from mlx_video.lora.apply import _normalize_wan_lora_key
|
||||
@@ -219,7 +237,9 @@ class TestApplyLoRA:
|
||||
# LoRA weights in float32 (typical when loaded from safetensors)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
|
||||
|
||||
@@ -230,7 +250,9 @@ class TestApplyLoRA:
|
||||
original = mx.ones((8, 4), dtype=mx.float16)
|
||||
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
|
||||
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
|
||||
|
||||
@@ -241,7 +263,9 @@ class TestApplyLoRA:
|
||||
original = mx.ones((8, 4))
|
||||
lora_a = mx.ones((2, 4)) * 0.1
|
||||
lora_b = mx.ones((8, 2)) * 0.1
|
||||
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
|
||||
w = LoRAWeights(
|
||||
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w, 1.0)])
|
||||
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
|
||||
expected = original + 0.02 * mx.ones((8, 4))
|
||||
@@ -255,12 +279,16 @@ class TestApplyLoRA:
|
||||
w1 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)),
|
||||
lora_B=mx.ones((8, 2)),
|
||||
rank=2, alpha=2.0, module_name="a",
|
||||
rank=2,
|
||||
alpha=2.0,
|
||||
module_name="a",
|
||||
)
|
||||
w2 = LoRAWeights(
|
||||
lora_A=mx.ones((2, 4)) * 2,
|
||||
lora_B=mx.ones((8, 2)) * 2,
|
||||
rank=2, alpha=4.0, module_name="b",
|
||||
rank=2,
|
||||
alpha=4.0,
|
||||
module_name="b",
|
||||
)
|
||||
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
|
||||
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
|
||||
@@ -282,7 +310,9 @@ class TestApplyLoRA:
|
||||
w = LoRAWeights(
|
||||
lora_A=mx.ones((4, 64)) * 0.01,
|
||||
lora_B=mx.ones((128, 4)) * 0.01,
|
||||
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
|
||||
rank=4,
|
||||
alpha=4.0,
|
||||
module_name="blocks.0.self_attn.q",
|
||||
)
|
||||
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
|
||||
result = apply_loras_to_weights(model_weights, module_to_loras)
|
||||
@@ -319,9 +349,7 @@ class TestEndToEnd:
|
||||
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
|
||||
}
|
||||
|
||||
result = load_and_apply_loras(
|
||||
model_weights, [(str(lora_path), 1.0)]
|
||||
)
|
||||
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
|
||||
|
||||
# q weight should be modified, k unchanged
|
||||
assert not mx.array_equal(
|
||||
|
||||
Reference in New Issue
Block a user