optimize memory usage by batching weight updates
This commit is contained in:
@@ -163,7 +163,7 @@ def load_and_merge_lora(
|
|||||||
else:
|
else:
|
||||||
sanitized_pairs = lora_pairs
|
sanitized_pairs = lora_pairs
|
||||||
|
|
||||||
# Get current model weights as a flat dict
|
# Get current model weights as a flat dict (references, not copies)
|
||||||
def flatten_params(params, prefix=""):
|
def flatten_params(params, prefix=""):
|
||||||
flat = {}
|
flat = {}
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
@@ -176,9 +176,11 @@ def load_and_merge_lora(
|
|||||||
|
|
||||||
flat_weights = flatten_params(dict(model.parameters()))
|
flat_weights = flatten_params(dict(model.parameters()))
|
||||||
|
|
||||||
# Merge LoRA deltas
|
# Merge LoRA deltas in batches to avoid doubling memory
|
||||||
merged_count = 0
|
merged_count = 0
|
||||||
updates = []
|
batch = []
|
||||||
|
batch_size = 100 # merge 100 weights at a time, then eval to free intermediates
|
||||||
|
|
||||||
for module_key, pair in sanitized_pairs.items():
|
for module_key, pair in sanitized_pairs.items():
|
||||||
if "A" not in pair or "B" not in pair:
|
if "A" not in pair or "B" not in pair:
|
||||||
continue
|
continue
|
||||||
@@ -193,13 +195,24 @@ def load_and_merge_lora(
|
|||||||
# delta = (lora_B * strength) @ lora_A
|
# delta = (lora_B * strength) @ lora_A
|
||||||
delta = (lora_b * strength) @ lora_a
|
delta = (lora_b * strength) @ lora_a
|
||||||
|
|
||||||
base_weight = flat_weights[weight_key].astype(mx.float32)
|
base_weight = flat_weights.pop(weight_key)
|
||||||
merged_weight = base_weight + delta
|
merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype)
|
||||||
updates.append((weight_key, merged_weight.astype(mx.bfloat16)))
|
batch.append((weight_key, merged_weight))
|
||||||
|
del base_weight
|
||||||
merged_count += 1
|
merged_count += 1
|
||||||
|
|
||||||
model.load_weights(updates, strict=False)
|
if len(batch) >= batch_size:
|
||||||
|
model.load_weights(batch, strict=False)
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
batch.clear()
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
model.load_weights(batch, strict=False)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
batch.clear()
|
||||||
|
|
||||||
|
del flat_weights, lora_weights
|
||||||
|
mx.clear_cache()
|
||||||
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
|
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user