format
This commit is contained in:
@@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
candidates = [lora_key]
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
candidates.append(lora_key[len(prefix):])
|
||||
candidates.append(lora_key[len(prefix) :])
|
||||
|
||||
for candidate in candidates:
|
||||
# Try as-is
|
||||
@@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
|
||||
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
|
||||
if transformed.endswith(".ffn.0"):
|
||||
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
|
||||
transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
|
||||
if transformed.endswith(".ffn.2"):
|
||||
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
|
||||
transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
|
||||
|
||||
# Text embedding: text_embedding.0 → text_embedding_0
|
||||
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
|
||||
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
|
||||
if transformed.endswith("text_embedding.0"):
|
||||
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
|
||||
transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
|
||||
if transformed.endswith("text_embedding.2"):
|
||||
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
|
||||
transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
|
||||
|
||||
# Time embedding: time_embedding.0 → time_embedding_0
|
||||
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
|
||||
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
|
||||
if transformed.endswith("time_embedding.0"):
|
||||
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
|
||||
transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
|
||||
if transformed.endswith("time_embedding.2"):
|
||||
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
|
||||
transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
|
||||
|
||||
# Time projection: time_projection.1 → time_projection
|
||||
transformed = transformed.replace("time_projection.1.", "time_projection.")
|
||||
if transformed.endswith("time_projection.1"):
|
||||
transformed = transformed[:-len("time_projection.1")] + "time_projection"
|
||||
transformed = transformed[: -len("time_projection.1")] + "time_projection"
|
||||
|
||||
# Patch embedding: patch_embedding → patch_embedding_proj
|
||||
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
|
||||
if (
|
||||
"patch_embedding" in transformed
|
||||
and "patch_embedding_proj" not in transformed
|
||||
):
|
||||
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
|
||||
|
||||
if f"{transformed}.weight" in model_keys or transformed in model_keys:
|
||||
@@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
# Return best attempt with prefix stripped
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
return lora_key[len(prefix):]
|
||||
return lora_key[len(prefix) :]
|
||||
|
||||
return lora_key
|
||||
|
||||
@@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
normalized = lora_key[len(prefix):]
|
||||
normalized = lora_key[len(prefix) :]
|
||||
|
||||
if f"{normalized}.weight" in model_keys or normalized in model_keys:
|
||||
return normalized
|
||||
|
||||
transformed = normalized
|
||||
if transformed.endswith(".to_out.0"):
|
||||
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
|
||||
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
|
||||
transformed = transformed.replace(
|
||||
".audio_ff.net.0.proj.", ".audio_ff.proj_in."
|
||||
)
|
||||
transformed = transformed.replace(
|
||||
".audio_ff.net.0.proj", ".audio_ff.proj_in"
|
||||
)
|
||||
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
|
||||
|
||||
@@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
# Try transformations on the original key
|
||||
transformed = lora_key
|
||||
if transformed.endswith(".to_out.0"):
|
||||
transformed = transformed[:-len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed[: -len(".to_out.0")] + ".to_out"
|
||||
transformed = transformed.replace(".to_out.0.", ".to_out.")
|
||||
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
|
||||
@@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
|
||||
|
||||
for prefix in prefixes_to_strip:
|
||||
if lora_key.startswith(prefix):
|
||||
return lora_key[len(prefix):]
|
||||
return lora_key[len(prefix) :]
|
||||
|
||||
return lora_key
|
||||
|
||||
@@ -226,7 +233,9 @@ def apply_loras_to_weights(
|
||||
skipped_count += 1
|
||||
skipped_modules.append(module_name)
|
||||
if verbose and skipped_count <= 5:
|
||||
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
|
||||
print(
|
||||
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
|
||||
)
|
||||
similar = [
|
||||
k
|
||||
for k in list(model_keys)[:1000]
|
||||
@@ -251,13 +260,21 @@ def apply_loras_to_weights(
|
||||
if is_quantized:
|
||||
scales = modified_weights[scales_key]
|
||||
biases = modified_weights[biases_key]
|
||||
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
|
||||
group_size = (original_weight.shape[-1] * 32) // (
|
||||
scales.shape[-1] * quantization_bits
|
||||
)
|
||||
dequantized = mx.dequantize(
|
||||
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
|
||||
original_weight,
|
||||
scales,
|
||||
biases,
|
||||
group_size=group_size,
|
||||
bits=quantization_bits,
|
||||
)
|
||||
modified = apply_lora_to_linear(dequantized, loras)
|
||||
# Re-quantize with same parameters
|
||||
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
|
||||
new_w, new_scales, new_biases = mx.quantize(
|
||||
modified, group_size=group_size, bits=quantization_bits
|
||||
)
|
||||
modified_weights[weight_key] = new_w
|
||||
modified_weights[scales_key] = new_scales
|
||||
modified_weights[biases_key] = new_biases
|
||||
@@ -346,9 +363,15 @@ def apply_loras_to_model(
|
||||
parent = model
|
||||
try:
|
||||
for part in parts[:-1]:
|
||||
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
|
||||
parent = (
|
||||
getattr(parent, part) if not part.isdigit() else parent[int(part)]
|
||||
)
|
||||
leaf_name = parts[-1]
|
||||
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
|
||||
target = (
|
||||
getattr(parent, leaf_name)
|
||||
if not leaf_name.isdigit()
|
||||
else parent[int(leaf_name)]
|
||||
)
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
skipped.append(lora_key)
|
||||
if verbose:
|
||||
@@ -358,8 +381,11 @@ def apply_loras_to_model(
|
||||
if isinstance(target, nn.QuantizedLinear):
|
||||
# Dequantize → merge LoRA → replace with bf16 Linear
|
||||
weight = mx.dequantize(
|
||||
target.weight, target.scales, target.biases,
|
||||
group_size=target.group_size, bits=target.bits,
|
||||
target.weight,
|
||||
target.scales,
|
||||
target.biases,
|
||||
group_size=target.group_size,
|
||||
bits=target.bits,
|
||||
)
|
||||
merged = apply_lora_to_linear(weight, loras)
|
||||
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
|
||||
@@ -379,7 +405,9 @@ def apply_loras_to_model(
|
||||
else:
|
||||
skipped.append(lora_key)
|
||||
if verbose:
|
||||
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
|
||||
print(
|
||||
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
|
||||
)
|
||||
continue
|
||||
|
||||
if applied_count > 0:
|
||||
|
||||
Reference in New Issue
Block a user