Optimize memory usage by merging LoRA weight updates in batches

This commit is contained in:
Prince Canuma
2026-03-15 03:12:47 +01:00
parent 53bae534e7
commit ebcd5dd4e4

View File

@@ -163,7 +163,7 @@ def load_and_merge_lora(
else: else:
sanitized_pairs = lora_pairs sanitized_pairs = lora_pairs
# Get current model weights as a flat dict # Get current model weights as a flat dict (references, not copies)
def flatten_params(params, prefix=""): def flatten_params(params, prefix=""):
flat = {} flat = {}
for k, v in params.items(): for k, v in params.items():
@@ -176,9 +176,11 @@ def load_and_merge_lora(
flat_weights = flatten_params(dict(model.parameters())) flat_weights = flatten_params(dict(model.parameters()))
# Merge LoRA deltas # Merge LoRA deltas in batches to avoid doubling memory
merged_count = 0 merged_count = 0
updates = [] batch = []
batch_size = 100 # merge 100 weights at a time, then eval to free intermediates
for module_key, pair in sanitized_pairs.items(): for module_key, pair in sanitized_pairs.items():
if "A" not in pair or "B" not in pair: if "A" not in pair or "B" not in pair:
continue continue
@@ -193,13 +195,24 @@ def load_and_merge_lora(
# delta = (lora_B * strength) @ lora_A # delta = (lora_B * strength) @ lora_A
delta = (lora_b * strength) @ lora_a delta = (lora_b * strength) @ lora_a
base_weight = flat_weights[weight_key].astype(mx.float32) base_weight = flat_weights.pop(weight_key)
merged_weight = base_weight + delta merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype)
updates.append((weight_key, merged_weight.astype(mx.bfloat16))) batch.append((weight_key, merged_weight))
del base_weight
merged_count += 1 merged_count += 1
model.load_weights(updates, strict=False) if len(batch) >= batch_size:
model.load_weights(batch, strict=False)
mx.eval(model.parameters()) mx.eval(model.parameters())
batch.clear()
if batch:
model.load_weights(batch, strict=False)
mx.eval(model.parameters())
batch.clear()
del flat_weights, lora_weights
mx.clear_cache()
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})") console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")