import copy
import json

from easydict import EasyDict
import torch.nn as nn


class DictX(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    ``d.key`` is equivalent to ``d['key']``. A missing key surfaces as
    ``AttributeError`` (instead of ``KeyError``) when accessed via attribute
    syntax, so ``getattr(d, name, default)`` and ``hasattr`` behave normally.
    """

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails, so real dict
        # methods (keys, items, ...) are never shadowed by stored items.
        try:
            return self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __setattr__(self, key, value):
        # Attribute assignment always stores an item in the mapping.
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as k:
            raise AttributeError(k)

    def __repr__(self):
        # Tag the plain dict representation with the class name so instances
        # are distinguishable from ordinary dicts when printed.
        return '<DictX ' + dict.__repr__(self) + '>'


def load_hparams(file_path):
    """Load hyper-parameters from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        An ``EasyDict`` (a ``dict`` subclass with attribute access) when the
        JSON document is an object; otherwise the decoded value unchanged.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        loaded = json.load(f)
    # Wrap only mappings: EasyDict cannot be built from a list or a scalar.
    if isinstance(loaded, dict):
        return EasyDict(loaded)
    return loaded


def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of ``model`` keeping only its first encoder layers.

    The input model is left untouched, and the returned model owns its own
    copies of the kept layers, so the two models do not share layer modules.

    Args:
        model: Object exposing an indexable ``model.encoder.layer`` sequence
            of modules (the original note says the full BERT model must be
            passed in).
        num_layers_to_keep: Number of leading encoder layers to retain.

    Returns:
        The truncated deep copy of ``model``.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of layers.
    """
    # Copy first, then truncate the copy: the kept layers then belong to the
    # copy rather than being borrowed from the original model.
    copy_of_model = copy.deepcopy(model)
    copied_layers = copy_of_model.encoder.layer

    # Index explicitly (rather than slicing) so asking for more layers than
    # exist raises instead of silently returning fewer.
    kept_layers = nn.ModuleList(
        [copied_layers[i] for i in range(num_layers_to_keep)]
    )
    copy_of_model.encoder.layer = kept_layers
    return copy_of_model