File size: 833 Bytes
31b540f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import sys
from safetensors.torch import save_file


# Strip trailing transformer blocks from a checkpoint: keep every tensor that
# appears before `stop_key`, drop `stop_key` and everything after it, but still
# keep the tensors named in `final_keys`. The result is written both as
# safetensors and as a regular torch checkpoint.

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only run this script on checkpoints from a source you trust.
model = torch.load('pytorch_model.bin', map_location="cpu")

# List every key so that `stop_key` and `final_keys` below can be chosen.
for key in model:
    print(key)
# sys.exit()  # Uncomment to stop right after listing the keys; comment it out again once stop_key and final_keys are set.


stop_key = 'h.12.ln_1.weight'  # The key you want to stop at (the previous key is kept)
final_keys = ['ln_f.weight', 'ln_f.bias']  # The final key/keys in the model which get saved.


stripped = {}
stopped = False
for key, value in model.items():
    if key == stop_key:
        # The stop key itself is dropped, along with everything after it
        # that is not listed in final_keys.
        stopped = True
        continue

    # Keep everything before the stop key, plus the explicitly listed final keys.
    if not stopped or key in final_keys:
        stripped[key] = value

if not stopped:
    # Without a match nothing was removed; make that visible instead of
    # silently writing an unmodified copy of the model.
    print(
        f"warning: stop_key {stop_key!r} not found in checkpoint; nothing was stripped",
        file=sys.stderr,
    )


save_file(stripped, 'pytorch_model_stripped.safetensors')
torch.save(stripped, 'pytorch_model_stripped.bin')