Optimize memory usage by batching LoRA weight updates
This commit is contained in:
@@ -163,7 +163,7 @@ def load_and_merge_lora(
|
||||
else:
|
||||
sanitized_pairs = lora_pairs
|
||||
|
||||
# Get current model weights as a flat dict
|
||||
# Get current model weights as a flat dict (references, not copies)
|
||||
def flatten_params(params, prefix=""):
|
||||
flat = {}
|
||||
for k, v in params.items():
|
||||
@@ -176,9 +176,11 @@ def load_and_merge_lora(
|
||||
|
||||
flat_weights = flatten_params(dict(model.parameters()))
|
||||
|
||||
# Merge LoRA deltas
|
||||
# Merge LoRA deltas in batches to avoid doubling memory
|
||||
merged_count = 0
|
||||
updates = []
|
||||
batch = []
|
||||
batch_size = 100 # merge 100 weights at a time, then eval to free intermediates
|
||||
|
||||
for module_key, pair in sanitized_pairs.items():
|
||||
if "A" not in pair or "B" not in pair:
|
||||
continue
|
||||
@@ -193,13 +195,24 @@ def load_and_merge_lora(
|
||||
# delta = (lora_B * strength) @ lora_A
|
||||
delta = (lora_b * strength) @ lora_a
|
||||
|
||||
base_weight = flat_weights[weight_key].astype(mx.float32)
|
||||
merged_weight = base_weight + delta
|
||||
updates.append((weight_key, merged_weight.astype(mx.bfloat16)))
|
||||
base_weight = flat_weights.pop(weight_key)
|
||||
merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype)
|
||||
batch.append((weight_key, merged_weight))
|
||||
del base_weight
|
||||
merged_count += 1
|
||||
|
||||
model.load_weights(updates, strict=False)
|
||||
mx.eval(model.parameters())
|
||||
if len(batch) >= batch_size:
|
||||
model.load_weights(batch, strict=False)
|
||||
mx.eval(model.parameters())
|
||||
batch.clear()
|
||||
|
||||
if batch:
|
||||
model.load_weights(batch, strict=False)
|
||||
mx.eval(model.parameters())
|
||||
batch.clear()
|
||||
|
||||
del flat_weights, lora_weights
|
||||
mx.clear_cache()
|
||||
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user