NextLaoHuang committed on
Commit
5b6ab9c
1 Parent(s): 8620090

Upload merge_lora.py

Browse files
Files changed (1) hide show
  1. merge_lora.py +53 -0
merge_lora.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from collections import OrderedDict
import sys
import typing
from typing import Dict

import torch


def merge_lora_weights(
    lora_alpha: float,
    base_model: str,
    lora: str,
    output: str,
    device: str = 'cpu',
) -> None:
    """Merge a LoRA slim checkpoint into a base model checkpoint.

    Loads ``base_model`` and ``lora`` (both torch ``.pth`` state dicts),
    folds every ``<prefix>.lora_A`` / ``<prefix>.lora_B`` pair into the
    matching ``<prefix>.weight`` as ``W += B @ A * (alpha / r)``, and
    saves the merged, LoRA-free state dict to ``output``.

    Args:
        lora_alpha: LoRA scaling numerator (divided by rank ``r``).
        base_model: path to the full base-model checkpoint.
        lora: path to the LoRA-only ("slim") checkpoint.
        output: path where the merged checkpoint is written.
        device: device used for the matmul ('cpu' or 'cuda'); the result
            is always copied back to CPU before saving.
    """
    with torch.no_grad():
        w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
        # Overlay the LoRA-only slim checkpoint onto the base weights; any
        # key present in the LoRA checkpoint replaces the base entry.
        w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
        for k in w_lora.keys():
            w[k] = w_lora[k]
        output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
        # Snapshot the key list: entries are deleted from `w` while iterating
        # to keep peak memory down.
        keys = list(w.keys())
        for k in keys:
            if k.endswith('.weight'):
                prefix = k[:-len('.weight')]
                lora_A = prefix + '.lora_A'
                lora_B = prefix + '.lora_B'
                if lora_A in keys:
                    assert lora_B in keys
                    print(f'merging {lora_A} and {lora_B} into {k}')
                    # B is (out, r), A is (r, in): the shared dim is the rank.
                    assert w[lora_B].shape[1] == w[lora_A].shape[0]
                    lora_r = w[lora_B].shape[1]
                    w[k] = w[k].to(device=device)
                    w[lora_A] = w[lora_A].to(device=device)
                    w[lora_B] = w[lora_B].to(device=device)
                    # W' = W + (alpha / r) * B @ A
                    w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
                    output_w[k] = w[k].to(device='cpu', copy=True)
                    del w[k]
                    del w[lora_A]
                    del w[lora_B]
                    continue

            # Pass through every non-LoRA tensor; lora_A/lora_B tensors that
            # were consumed above (or any stray 'lora' keys) are dropped.
            if 'lora' not in k:
                print(f'retaining {k}')
                output_w[k] = w[k].clone()
                del w[k]

    torch.save(output_w, output)


def main() -> None:
    """Parse command-line arguments and run the merge."""
    usage = (
        f'Usage: python3 {sys.argv[0]} [--use-gpu] '
        '<lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>'
    )
    if '-h' in sys.argv or '--help' in sys.argv:
        # BUG FIX: the original printed usage but kept going and crashed
        # while parsing positional arguments.
        print(usage)
        sys.exit(0)

    args = sys.argv[1:]
    if args and args[0] == '--use-gpu':
        device = 'cuda'
        args = args[1:]
    else:
        device = 'cpu'

    # BUG FIX: validate argument count instead of dying with IndexError.
    if len(args) < 4:
        print(usage)
        sys.exit(1)

    lora_alpha = float(args[0])
    merge_lora_weights(lora_alpha, args[1], args[2], args[3], device=device)


if __name__ == '__main__':
    main()