anas-awadalla
commited on
Commit
•
28a1d28
1
Parent(s):
46b0f0f
removed lm weights from checkpoint
Browse files- checkpoint.pt +2 -2
- clean_checkpoint.py +12 -0
checkpoint.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2a7cc063ff3f02187dba34c45a7f58e60c291b7b37144a965edccd0c877c8f5a
|
3 |
+
size 4872875366
|
clean_checkpoint.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# Load the checkpoint
|
4 |
+
checkpoint = torch.load('checkpoint_cleaned.pt', map_location=torch.device('cpu'))
|
5 |
+
print(checkpoint.keys())
|
6 |
+
# remove keys of fform lang_encoder.gpt_neox.layers.x.decoder_layer
|
7 |
+
for key in list(checkpoint.keys()):
|
8 |
+
if 'decoder_layer' in key:
|
9 |
+
del checkpoint[key]
|
10 |
+
|
11 |
+
# save the checkpoint
|
12 |
+
torch.save(checkpoint, 'checkpoint_cleaned.pt')
|