File size: 1,344 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import json
from transformers import BertForMaskedLM, BertTokenizer
# Fill literal 'UNK' placeholders in generated samples with BERT masked-LM
# predictions. Each sample is framed by its 'obs1' / 'obs2' context, every
# 'UNK' becomes BERT's mask token, and each mask is replaced by the top-1
# predicted token. The input file is read as JSON Lines (one object per line
# with 'sample', 'obs1' and 'obs2' keys).
filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'
bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
# Fall back to CPU so the script still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForMaskedLM.from_pretrained(bert_model).to(device)
model.eval()  # make inference mode explicit (disables dropout)

full_lst = []
with open(filename, 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        if line.strip():
            full_lst.append(json.loads(line))

for example in full_lst:
    sent = example['sample']
    obs1 = example['obs1']
    obs2 = example['obs2']
    if 'UNK' not in sent:
        print('NO NEED THIS FIX. ')
        continue
    # Join with spaces so the last word of obs1 / first word of obs2 cannot
    # fuse with the sample's boundary words into a single token.
    sent = ' '.join([obs1, sent.replace('UNK', tokenizer.mask_token), obs2])
    print(sent)
    model_inputs = tokenizer(sent, return_tensors="pt")
    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        model_out = model(**model_inputs)
    mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
    # Boolean-mask indexing over (batch, seq) yields (num_masks, vocab_size).
    masked_logits = model_out.logits[mask_words]
    if masked_logits.size(0) > 0:
        # Replace each mask with its argmax token. The assignment must live
        # inside this guard; otherwise the predictions would be undefined
        # (NameError) or stale from a previous example.
        model_inputs['input_ids'][mask_words] = masked_logits.argmax(dim=-1)
    print(tokenizer.batch_decode(model_inputs['input_ids'].tolist()))