diff --git a/mlx_video/generate.py b/mlx_video/generate.py index fe3cbe9..f542abb 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -163,7 +163,7 @@ def load_and_merge_lora( else: sanitized_pairs = lora_pairs - # Get current model weights as a flat dict + # Get current model weights as a flat dict (references, not copies) def flatten_params(params, prefix=""): flat = {} for k, v in params.items(): @@ -176,9 +176,11 @@ def load_and_merge_lora( flat_weights = flatten_params(dict(model.parameters())) - # Merge LoRA deltas + # Merge LoRA deltas in batches to avoid doubling memory merged_count = 0 - updates = [] + batch = [] + batch_size = 100 # merge 100 weights at a time, then eval to free intermediates + for module_key, pair in sanitized_pairs.items(): if "A" not in pair or "B" not in pair: continue @@ -193,13 +195,24 @@ def load_and_merge_lora( # delta = (lora_B * strength) @ lora_A delta = (lora_b * strength) @ lora_a - base_weight = flat_weights[weight_key].astype(mx.float32) - merged_weight = base_weight + delta - updates.append((weight_key, merged_weight.astype(mx.bfloat16))) + base_weight = flat_weights.pop(weight_key) + merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype) + batch.append((weight_key, merged_weight)) + del base_weight merged_count += 1 - model.load_weights(updates, strict=False) - mx.eval(model.parameters()) + if len(batch) >= batch_size: + model.load_weights(batch, strict=False) + mx.eval(model.parameters()) + batch.clear() + + if batch: + model.load_weights(batch, strict=False) + mx.eval(model.parameters()) + batch.clear() + + del flat_weights, lora_weights + mx.clear_cache() console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")