"""
Based on https://github.com/huggingface/transformers/issues/80
"""
import json
import argparse
import glob
import sys
import logging
import os
import time
import string
from filelock import FileLock
import numpy as np
import pickle
import torch
from torch.utils.data import DataLoader, TensorDataset
from .base import BaseModule, create_trainer
from ..data.examples import InputFeatures
from collections import ChainMap
from torch.utils.data import DataLoader, TensorDataset
logger = logging.getLogger(__name__)
class MaskedLM(BaseModule):
    """Evaluate a masked language model on multiple-choice questions.

    Each question contains a [MASK] slot; every candidate ending is reduced
    to a single representative token id, and the model's vocabulary
    distribution at the first [MASK] position is restricted to those
    candidates to pick an answer.

    Based on https://github.com/huggingface/transformers/issues/80
    """

    mode = 'language-modeling'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        super().__init__(hparams)
        self.mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')
        # Counts are accumulated in this pickle file so concurrent test
        # workers (e.g. DDP processes) can merge results; start clean.
        self.test_results_fpath = 'test_results'
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)

    def convert_examples_to_features(self, examples):
        """Tokenize questions and map each candidate ending to one token id.

        Each candidate string is tokenized and its longest sub-token (length
        measured after stripping punctuation) is kept as the representative;
        a candidate that tokenizes to nothing falls back to the UNK token so
        the candidate list stays aligned with the labels.
        """
        batch_encoding = self.tokenizer(
            [example.question for example in examples],
            max_length=self.hparams['max_seq_length'],
            padding='max_length',
            truncation=True,
        )
        features = []
        for i, example in enumerate(examples):
            inputs = {k: batch_encoding[k][i] for k in batch_encoding}
            token_candidates = []
            for cand in example.endings:
                toks = self.tokenizer.tokenize(cand)
                if not toks:
                    # Candidate vanished under tokenization; keep alignment.
                    token_candidates.append(self.tokenizer.unk_token)
                else:
                    # Keep the most contentful sub-token of the candidate.
                    token_candidates.append(
                        max(toks, key=lambda t: len(t.strip(string.punctuation)))
                    )
            candidate_ids = self.tokenizer.convert_tokens_to_ids(token_candidates)
            features.append(
                InputFeatures(**inputs, candidates=candidate_ids, label=example.label)
            )
        return features

    def test_dataloader(self):
        """Build (and cache to disk) the test set as a TensorDataset loader."""
        mode = 'test'
        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file) and not self.hparams['overwrite_cache']:
            features = torch.load(cached_features_file)
        else:
            features = self.load_features(mode)
            torch.save(features, cached_features_file)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor(
            [f.attention_mask for f in features], dtype=torch.long
        )
        # BUG FIX: the previous `f.token_type_ids or 0` substituted a scalar 0
        # for a missing segment-id list, which torch.tensor cannot stack with
        # real lists; use a zero list of matching length instead.
        all_token_type_ids = torch.tensor(
            [
                f.token_type_ids if f.token_type_ids is not None
                else [0] * len(f.input_ids)
                for f in features
            ],
            dtype=torch.long,
        )
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        all_cands = torch.tensor([f.candidates for f in features], dtype=torch.long)
        # NOTE: a second, identical copy of the labels used to be appended as
        # a sixth tensor; test_step only reads indices 0-4, so it is dropped.
        return DataLoader(
            TensorDataset(
                all_input_ids, all_attention_mask, all_token_type_ids,
                all_labels, all_cands,
            ),
            batch_size=self.hparams['eval_batch_size'],
        )

    def test_step(self, batch, batch_idx):
        """Score one batch.

        Returns a dict with the counts of examples where the model assigned
        the highest probability (among the candidate token ids, at the first
        [MASK] position) to the gold candidate ("right") versus not ("wrong").
        """
        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'token_type_ids': batch[2]}
        answers = batch[3].detach().cpu().numpy()
        candidates = batch[4].detach().cpu().numpy()
        # Locate the first [MASK] token in every sequence.
        input_ids = batch[0].detach().cpu().numpy()
        mask_ids = torch.from_numpy((input_ids == self.mask_id).argmax(axis=1))
        predictions = self(**inputs)[0]
        rows = torch.arange(0, predictions.shape[0], dtype=torch.int64)
        # Keep only the vocabulary distribution at each sequence's mask slot.
        predictions = predictions[rows, mask_ids].detach().cpu().numpy()
        right, wrong = 0, 0
        for i, pred in enumerate(predictions):
            # Restrict scoring to this example's candidate token ids; the
            # label is an index into the candidate list, not a vocab id.
            pred_answer = int(np.argmax(pred[candidates[i]]))
            if answers[i] == pred_answer:
                right += 1
            else:
                wrong += 1
        return {"right": right, "wrong": wrong}

    def test_epoch_end(self, outputs):
        """Merge per-step counts into a shared on-disk tally.

        A file lock serializes concurrent writers (e.g. DDP workers) so the
        read-modify-write of the running totals is race-free.
        """
        merged = {
            'right': sum(output['right'] for output in outputs),
            'wrong': sum(output['wrong'] for output in outputs),
        }
        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = {'right': data['right'] + merged['right'],
                        'wrong': data['wrong'] + merged['wrong']}
            else:
                data = merged
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)
        return data

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """This module adds no extra CLI arguments; return parser unchanged."""
        return parser

    def run_module(self):
        """Run the test loop and write final accuracy to output_dir."""
        self.eval()
        self.freeze()
        # BUG FIX: the original called `torch.no_grad()` as a bare statement,
        # which builds a context manager and immediately discards it — a
        # no-op. Disable gradient tracking for real.
        torch.set_grad_enabled(False)
        trainer = create_trainer(self, self.hparams)
        trainer.test(self)
        # Use a context manager so the results file handle is not leaked.
        with open(self.test_results_fpath, 'rb') as fp:
            preds = pickle.load(fp)
        correct, wrong = preds['right'], preds['wrong']
        total = correct + wrong
        with open(os.path.join(self.hparams['output_dir'], 'test_results.txt'), 'w') as fp:
            # Guard against an empty test set (avoids ZeroDivisionError).
            json.dump({'test_acc': correct / total if total else 0.0}, fp)