optimize memory usage by batching weight updates

This commit is contained in:
Prince Canuma
2026-03-15 03:12:47 +01:00
parent 53bae534e7
commit ebcd5dd4e4

View File

@@ -163,7 +163,7 @@ def load_and_merge_lora(
else:
sanitized_pairs = lora_pairs
# Get current model weights as a flat dict
# Get current model weights as a flat dict (references, not copies)
def flatten_params(params, prefix=""):
flat = {}
for k, v in params.items():
@@ -176,9 +176,11 @@ def load_and_merge_lora(
flat_weights = flatten_params(dict(model.parameters()))
# Merge LoRA deltas
# Merge LoRA deltas in batches to avoid doubling memory
merged_count = 0
updates = []
batch = []
batch_size = 100 # merge 100 weights at a time, then eval to free intermediates
for module_key, pair in sanitized_pairs.items():
if "A" not in pair or "B" not in pair:
continue
@@ -193,13 +195,24 @@ def load_and_merge_lora(
# delta = (lora_B * strength) @ lora_A
delta = (lora_b * strength) @ lora_a
base_weight = flat_weights[weight_key].astype(mx.float32)
merged_weight = base_weight + delta
updates.append((weight_key, merged_weight.astype(mx.bfloat16)))
base_weight = flat_weights.pop(weight_key)
merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype)
batch.append((weight_key, merged_weight))
del base_weight
merged_count += 1
model.load_weights(updates, strict=False)
mx.eval(model.parameters())
if len(batch) >= batch_size:
model.load_weights(batch, strict=False)
mx.eval(model.parameters())
batch.clear()
if batch:
model.load_weights(batch, strict=False)
mx.eval(model.parameters())
batch.clear()
del flat_weights, lora_weights
mx.clear_cache()
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")