|
import json, copy |
|
from easydict import EasyDict |
|
|
|
import torch.nn as nn |
|
|
|
class DictX(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes ``d["key"]``.  A missing key surfaces as
    ``AttributeError`` so that ``getattr(d, name, default)`` and ``hasattr``
    behave as they do for ordinary objects.
    """

    def __getattr__(self, key):
        """Return ``self[key]``; raise ``AttributeError`` if the key is absent."""
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so real dict methods (``items``, ``keys``, ...) are never shadowed.
        try:
            return self[key]
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None

    def __setattr__(self, key, value):
        """Store *value* as the dictionary item *key*."""
        self[key] = value

    def __delattr__(self, key):
        """Delete the item *key*; raise ``AttributeError`` if it is absent."""
        try:
            del self[key]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None

    def __repr__(self):
        """Return ``<ClassName {...}>`` using the plain dict representation."""
        # Use the runtime class name so subclasses are reported correctly.
        return f"<{type(self).__name__} {dict.__repr__(self)}>"
|
|
|
|
|
def load_hparams(file_path):
    """Load hyper-parameters from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        An ``EasyDict`` built from the parsed data when the top-level JSON
        value is an object, so entries are reachable both as ``hparams["x"]``
        and ``hparams.x``.  Any other top-level JSON value (list, number, ...)
        is returned exactly as parsed.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # The parsed dict used to overwrite a freshly created (and therefore
    # unused) EasyDict; wrap the parsed data instead so attribute access works.
    if isinstance(data, dict):
        return EasyDict(data)
    return data
|
|
|
|
|
def deleteEncodingLayers(model, num_layers_to_keep):
    """Return a deep copy of *model* that keeps only its first encoder layers.

    The input model is left untouched.  The returned model is fully
    independent: its remaining layers are the copy's own modules, so no
    parameters are shared with *model*.  Only ``encoder.layer`` is replaced;
    nothing else on the copy is updated.

    Args:
        model: Module exposing ``model.encoder.layer`` as an indexable
            container of layers (presumably a BERT-style encoder -- confirm
            against callers).
        num_layers_to_keep: Number of leading layers to retain.  Values <= 0
            produce an empty layer list.

    Returns:
        The copied model whose ``encoder.layer`` is a new ``nn.ModuleList``
        holding the first ``num_layers_to_keep`` layers.

    Raises:
        IndexError: If ``num_layers_to_keep`` exceeds the number of layers.
    """
    # Copy first, then prune the copy's own layers.  Taking the layers from
    # the original model and attaching them to the copy would make both
    # models share (and jointly mutate) the same parameters.
    pruned_model = copy.deepcopy(model)
    copied_layers = pruned_model.encoder.layer

    # Index explicitly (rather than slice) so that asking for more layers
    # than exist still raises IndexError instead of silently clamping.
    pruned_model.encoder.layer = nn.ModuleList(
        [copied_layers[i] for i in range(num_layers_to_keep)]
    )
    return pruned_model