File size: 1,167 Bytes
ecdea35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import json, copy
from easydict import EasyDict
import torch.nn as nn
class DictX(dict):
    """A ``dict`` whose items can also be used as attributes.

    Reading, assigning and deleting ``obj.name`` is forwarded to the
    corresponding item operation on ``obj['name']``. A missing key is
    reported as ``AttributeError`` rather than ``KeyError``.
    """

    def __getattr__(self, key):
        # Python calls this only after normal attribute lookup has failed,
        # so regular dict methods are not shadowed by stored keys.
        try:
            value = self[key]
        except KeyError as missing:
            raise AttributeError(missing)
        return value

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dictionary item.
        self.__setitem__(key, value)

    def __delattr__(self, key):
        try:
            self.__delitem__(key)
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        return '<DictX {}>'.format(dict.__repr__(self))
def load_hparams(file_path):
    """Load hyper-parameters from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` (a plain ``dict`` when the
        file holds a JSON object). The result is not wrapped in
        ``EasyDict``, so use item access (``hparams['key']``).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of ``model`` keeping only its first encoder layers.

    The input model is left untouched. The returned model owns independent
    copies of the kept layers, so changing its parameters does not affect
    ``model``.

    Args:
        model: A module exposing ``model.encoder.layer`` as an indexable
            sequence of layers (the caller is expected to pass the full
            BERT model).
        num_layers_to_keep: Number of leading encoder layers to retain.

    Returns:
        A deep copy of ``model`` whose ``encoder.layer`` is a new
        ``nn.ModuleList`` holding the first ``num_layers_to_keep`` layers.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of layers.
    """
    # Copy first, then pick layers from the copy: taking them from `model`
    # would make the result share layer parameters with the original.
    copy_of_model = copy.deepcopy(model)
    copied_layers = copy_of_model.encoder.layer

    # Index one by one (instead of slicing) so that asking for more layers
    # than exist fails loudly rather than silently returning fewer layers.
    copy_of_model.encoder.layer = nn.ModuleList(
        [copied_layers[i] for i in range(num_layers_to_keep)]
    )
    return copy_of_model