This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
candidates.append(lora_key[len(prefix) :])
for candidate in candidates:
# Try as-is
@@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
transformed = transformed[: -len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
if (
"patch_embedding" in transformed
and "patch_embedding_proj" not in transformed
):
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
@@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key[len(prefix) :]
return lora_key
@@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
normalized = lora_key[len(prefix) :]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(
".audio_ff.net.0.proj.", ".audio_ff.proj_in."
)
transformed = transformed.replace(
".audio_ff.net.0.proj", ".audio_ff.proj_in"
)
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
@@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
@@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key[len(prefix) :]
return lora_key
@@ -226,7 +233,9 @@ def apply_loras_to_weights(
skipped_count += 1
skipped_modules.append(module_name)
if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
print(
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
)
similar = [
k
for k in list(model_keys)[:1000]
@@ -251,13 +260,21 @@ def apply_loras_to_weights(
if is_quantized:
scales = modified_weights[scales_key]
biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
group_size = (original_weight.shape[-1] * 32) // (
scales.shape[-1] * quantization_bits
)
dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
original_weight,
scales,
biases,
group_size=group_size,
bits=quantization_bits,
)
modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
new_w, new_scales, new_biases = mx.quantize(
modified, group_size=group_size, bits=quantization_bits
)
modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases
@@ -346,9 +363,15 @@ def apply_loras_to_model(
parent = model
try:
for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
parent = (
getattr(parent, part) if not part.isdigit() else parent[int(part)]
)
leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
target = (
getattr(parent, leaf_name)
if not leaf_name.isdigit()
else parent[int(leaf_name)]
)
except (AttributeError, IndexError, TypeError):
skipped.append(lora_key)
if verbose:
@@ -358,8 +381,11 @@ def apply_loras_to_model(
if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize(
target.weight, target.scales, target.biases,
group_size=target.group_size, bits=target.bits,
target.weight,
target.scales,
target.biases,
group_size=target.group_size,
bits=target.bits,
)
merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
@@ -379,7 +405,9 @@ def apply_loras_to_model(
else:
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
print(
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
)
continue
if applied_count > 0: