from everything import * | |
from bert import BertModel | |
def get_finetuned_bert(mode: str):
    """Build a BERT model and load fine-tuned weights into it.

    A ``bert-base-uncased`` model is created with
    ``BertModel.from_pretrained``; its parameters are then overwritten with
    the state dict stored at ``SUP_BERT`` (``mode == 'sup'``) or
    ``UNSUP_BERT`` (``mode == 'unsup'``).

    Args:
        mode: Which fine-tuned checkpoint to load, ``'sup'`` or ``'unsup'``.

    Returns:
        The model with the fine-tuned weights loaded, moved to CUDA when
        ``USE_GPU`` is truthy and to the CPU otherwise.

    Raises:
        AssertionError: If ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; otherwise an unknown mode would silently fall through
    # and load the unsupervised weights. The exception type is unchanged.
    if mode not in ('sup', 'unsup'):
        raise AssertionError(
            f"mode must be 'sup' or 'unsup', got {mode!r}"
        )

    bert = BertModel.from_pretrained('bert-base-uncased')

    checkpoint = SUP_BERT if mode == 'sup' else UNSUP_BERT
    # map_location='cpu' lets a checkpoint that was saved from CUDA tensors
    # be deserialized on a host without a GPU; the model is moved to the
    # target device only after the weights have been copied in.
    state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
    bert.load_state_dict(state_dict)

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    return bert.to(device)