# minBERT / train_scripts/finetuned_bert.py
# Provenance: commit 8eff58f ("Final touch") by GlowCheese, 474 bytes.
from everything import *
from bert import BertModel
def get_finetuned_bert(mode: str) -> "BertModel":
    """Build a ``bert-base-uncased`` BertModel and load fine-tuned weights into it.

    Args:
        mode: ``'sup'`` loads the state dict stored at ``SUP_BERT``;
            ``'unsup'`` loads the one stored at ``UNSUP_BERT``.

    Returns:
        The model with the checkpoint loaded, moved to CUDA when ``USE_GPU``
        is truthy and to CPU otherwise.

    Raises:
        ValueError: if ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently route any bad mode to UNSUP_BERT.
    # Done first so an invalid mode fails before any expensive loading.
    if mode not in ('sup', 'unsup'):
        raise ValueError(f"mode must be 'sup' or 'unsup', got {mode!r}")

    checkpoint_path = SUP_BERT if mode == 'sup' else UNSUP_BERT
    device = torch.device('cuda' if USE_GPU else 'cpu')

    bert = BertModel.from_pretrained('bert-base-uncased')
    # Remap storages to CPU. Without map_location, a checkpoint saved from
    # CUDA tensors cannot be deserialized on a machine without a GPU.
    # The whole model is moved to `device` afterwards.
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    bert.load_state_dict(state_dict)
    return bert.to(device)