BiBERTa / util /utils.py
jinysun's picture
Upload 17 files
ecdea35
raw
history blame
1.17 kB
import json, copy
from easydict import EasyDict
import torch.nn as nn
class DictX(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``d.name`` is equivalent to ``d['name']``. A missing key is reported as
    ``AttributeError`` for attribute-style access, so ``hasattr`` and
    ``getattr(d, name, default)`` behave as they do for ordinary objects.
    """

    def __getattr__(self, key):
        # Only reached when regular attribute lookup fails; fall back to
        # the mapping and translate a missing key into AttributeError.
        try:
            value = self[key]
        except KeyError as missing:
            raise AttributeError(missing)
        return value

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dictionary item.
        self[key] = value

    def __delattr__(self, key):
        # Deleting an attribute removes the matching dictionary item.
        try:
            del self[key]
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        return '<DictX {}>'.format(dict.__repr__(self))
def load_hparams(file_path):
    """Load hyperparameters from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document exactly as ``json.load`` produces it,
        i.e. a plain ``dict`` when the top-level JSON value is an object.
        Wrap the result (for example in ``DictX``) if attribute-style
        access is required.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # NOTE(review): an empty EasyDict used to be created here and was
    # immediately overwritten by the parsed JSON, so the return value has
    # always been json.load's plain result. The dead assignment is dropped.
    #
    # JSON text is UTF-8 by specification; without an explicit encoding
    # open() uses the platform locale and can mis-decode non-ASCII values.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of ``model`` that keeps only its first encoder layers.

    The input model is left untouched. The returned model owns independent
    copies of the retained layers, so training or modifying it does not
    affect ``model``.

    Args:
        model: A module exposing its encoder layers as the indexable
            ``model.encoder.layer`` container.
        num_layers_to_keep: Number of leading encoder layers to retain.
            Zero or a negative value yields an empty layer list.

    Returns:
        A deep copy of ``model`` whose ``encoder.layer`` is a new
        ``nn.ModuleList`` holding the first ``num_layers_to_keep`` layers.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of
            available encoder layers.
    """
    # Fail fast, before paying for a deep copy of the whole model.
    available = len(model.encoder.layer)
    if num_layers_to_keep > available:
        raise IndexError(
            f"cannot keep {num_layers_to_keep} encoder layers; "
            f"the model only has {available}"
        )

    # Copy first, then pick the layers out of the copy. Selecting them from
    # the original model would make the "copy" share layer parameters with it.
    copyOfModel = copy.deepcopy(model)
    keptLayers = [copyOfModel.encoder.layer[i] for i in range(num_layers_to_keep)]
    copyOfModel.encoder.layer = nn.ModuleList(keptLayers)

    # NOTE(review): any layer count stored elsewhere on the model (e.g. in a
    # config object) is not updated here - confirm whether callers rely on it.
    return copyOfModel