import torch

from everything import *
from bert import BertModel


def get_finetuned_bert(mode: str) -> BertModel:
    """Build a ``bert-base-uncased`` BertModel and load fine-tuned weights into it.

    Args:
        mode: ``'sup'`` loads the state dict stored at ``SUP_BERT``;
            ``'unsup'`` loads the state dict stored at ``UNSUP_BERT``.

    Returns:
        The model with the checkpoint weights loaded, moved to CUDA when
        ``USE_GPU`` is truthy and to the CPU otherwise.

    Raises:
        ValueError: If ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let an invalid mode silently fall through to the unsup branch.
    if mode not in ('sup', 'unsup'):
        raise ValueError(f"mode must be 'sup' or 'unsup', got {mode!r}")

    device = torch.device('cuda') if USE_GPU else torch.device('cpu')
    bert = BertModel.from_pretrained('bert-base-uncased')

    checkpoint = SUP_BERT if mode == 'sup' else UNSUP_BERT
    # Always deserialize onto the CPU: a checkpoint saved from a GPU run would
    # otherwise be restored onto CUDA, which fails on CPU-only hosts. The
    # weights are copied into the model's parameters and the model is then
    # moved to the target device once, below.
    state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
    bert.load_state_dict(state_dict)
    return bert.to(device)