import torch
import json
from transformers import BertForMaskedLM, BertTokenizer
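
# Post-process Diffusion-LM samples: wherever a generated sample still contains
# the UNK token, surround it with its two observations (obs1/obs2) and let a
# BERT masked language model fill in the missing words.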
filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
model = BertForMaskedLM.from_pretrained(bert_model).cuda()
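
# Read the generated samples (one JSON object per line).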
full_lst = []
with open(filename, 'r') as f:
    for line in f:
        line = json.loads(line)
        full_lst.append(line)
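
# Fix each sample in turn.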
for example in full_lst:
    sent = example['sample']
    obs1 = example['obs1']
    obs2 = example['obs2']
    if 'UNK' in sent:
        # Surround the sample with its observations and turn UNK into BERT's mask token.
        sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
        print(sent)
        model_inputs = tokenizer(sent, return_tensors="pt")
        model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
        model_out = model(**model_inputs)
        # Keep only the logits at the masked positions.
        mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
        masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
        if masked_logits.size(0) > 0:
            # take argmax from this.
            max_cands = torch.max(masked_logits, dim=-1)
            indices = max_cands.indices
            model_inputs['input_ids'][mask_words] = indices
            print(tokenizer.batch_decode(model_inputs['input_ids'].tolist()))
    else:
        print('NO NEED THIS FIX. ')