NextLaoHuang committed
Commit: 5b6ab9c
Parent(s): 8620090
Upload merge_lora.py
Files changed: merge_lora.py (added, +53 -0)
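The new file folds trained LoRA adapter weights back into a base checkpoint: for every `<prefix>.weight` tensor that has matching `<prefix>.lora_A` and `<prefix>.lora_B` tensors, the script adds `B @ A * (lora_alpha / r)` to the base weight, drops the adapter tensors, and saves a plain merged checkpoint.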
from collections import OrderedDict
import os
import sys
from typing import Dict
import typing
import torch

if '-h' in sys.argv or '--help' in sys.argv:
    print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')
    sys.exit(0)

if sys.argv[1] == '--use-gpu':
    device = 'cuda'
    lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
else:
    device = 'cpu'
    lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]


with torch.no_grad():
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
    # merge LoRA-only slim checkpoint into the main weights
    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
    for k in w_lora.keys():
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
    # merge LoRA weights: W += B @ A * (alpha / r)
    keys = list(w.keys())
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                # lora_B is (out, r) and lora_A is (r, in); r is the LoRA rank
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
                w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
                output_w[k] = w[k].to(device='cpu', copy=True)
                # free merged tensors as we go to keep peak memory low
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        # keep every non-LoRA tensor; the lora_A / lora_B tensors are dropped
        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]

    torch.save(output_w, output)
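A quick way to sanity-check the merge rule the script applies: folding `B @ A * (alpha / r)` into the base weight should give the same layer output as keeping the LoRA path separate. The sketch below is not part of the commit; shapes, seed, and the alpha value are illustrative assumptions.

import torch

torch.manual_seed(0)
d_out, d_in, r, alpha = 8, 16, 4, 32.0  # hypothetical layer sizes, rank, and scaling
W = torch.randn(d_out, d_in, dtype=torch.float64)  # base weight
A = torch.randn(r, d_in, dtype=torch.float64)      # lora_A
B = torch.randn(d_out, r, dtype=torch.float64)     # lora_B
x = torch.randn(d_in, dtype=torch.float64)

# merge rule used in merge_lora.py: W' = W + B @ A * (alpha / r)
W_merged = W + B @ A * (alpha / r)

y_separate = W @ x + B @ (A @ x) * (alpha / r)  # base path + LoRA path
y_merged = W_merged @ x                         # merged weight only
assert torch.allclose(y_separate, y_merged)

To run the script itself, pass the same `lora_alpha` that the LoRA checkpoint was trained with, for example (file names and the alpha value are placeholders):

python3 merge_lora.py 32 base_model.pth lora_checkpoint.pth merged.pth

Adding `--use-gpu` before the alpha performs the matrix products on CUDA instead of the CPU.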