"""Post-process aNLG generations: use a pretrained BERT masked LM to fill in any
leftover UNK tokens in the generated infill, conditioned on the two observations."""

import json

import torch
from transformers import BertForMaskedLM, BertTokenizer

# Path to the generated samples (one JSON object per line); kept as in the original.
filename = 'diffusion_lm/improved-diffusion/anlg_results/diff_roc_mbr.json2'

# Load BERT and its tokenizer; the model is used purely for inference on GPU.
bert_model = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model)
model = BertForMaskedLM.from_pretrained(bert_model).cuda()
model.eval()

# Read all generated examples into memory.
full_lst = []
with open(filename, 'r') as f:
    for line in f:
        line = json.loads(line)
        full_lst.append(line)

for example in full_lst:
    sent = example['sample']
    obs1 = example['obs1']
    obs2 = example['obs2']
    if 'UNK' in sent:
        # Wrap the generated infill with the two observations for context and turn
        # every UNK into a BERT [MASK] token.
        sent = obs1 + sent.replace('UNK', tokenizer.mask_token) + obs2
    print(sent)
    model_inputs = tokenizer(sent, return_tensors="pt")
    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
    with torch.no_grad():
        model_out = model(**model_inputs)
    # Collect the logits at the masked positions only.
    mask_words = model_inputs['input_ids'] == tokenizer.mask_token_id
    masked_logits = model_out.logits[mask_words].view(-1, model_out.logits.size(-1))
    if masked_logits.size(0) > 0:
        # Take the argmax prediction at each masked position, write it back into
        # the input ids, then decode the completed sentence.
        max_cands = torch.max(masked_logits, dim=-1)
        indices = max_cands.indices
        model_inputs['input_ids'][mask_words] = indices
        print(tokenizer.batch_decode(model_inputs['input_ids'].tolist()))
    else:
        print('NO NEED FOR THIS FIX.')
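# Expected input format (an assumption inferred from the keys read above, not
# documented by the script itself): each line of `filename` is a standalone JSON
# object roughly of the form
#     {"sample": "... UNK ...", "obs1": "...", "obs2": "..."}
# where 'sample' is the generated infill (possibly containing UNK placeholders) and
# 'obs1' / 'obs2' are the two observations used as context for BERT.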