sunit333's picture
Upload 63 files
d08dd00 verified
raw
history blame
No virus
5.39 kB
"""
Based on https://github.com/huggingface/transformers/issues/80
"""
import json
import argparse
import glob
import sys
import logging
import os
import time
import string
from filelock import FileLock
import numpy as np
import pickle
import torch
from torch.utils.data import DataLoader, TensorDataset
from .base import BaseModule, create_trainer
from ..data.examples import InputFeatures
from collections import ChainMap
from torch.utils.data import DataLoader, TensorDataset
# Module-level logger named after this module (standard logging pattern).
logger = logging.getLogger(__name__)
class MaskedLM(BaseModule):
    """Zero-shot multiple-choice evaluation with a masked language model.

    Each example's ``question`` is expected to contain a mask token. The model's
    logits at the first mask position are read at one representative vocabulary
    id per candidate ending, and the highest-scoring candidate is the prediction.
    Right/wrong counts are accumulated in ``test_results_fpath`` and the final
    accuracy is written to ``<output_dir>/test_results.txt`` by ``run_module``.
    """

    mode = 'language-modeling'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        super().__init__(hparams)
        # Prefer the tokenizer's own mask token id; fall back to looking up the
        # literal '[MASK]' token for tokenizers that do not define a mask token.
        mask_id = getattr(self.tokenizer, 'mask_token_id', None)
        if mask_id is None:
            mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')
        self.mask_id = mask_id
        # Scratch file that test_epoch_end accumulates counts into; start clean so
        # counts from an earlier run are not added to this one.
        self.test_results_fpath = 'test_results'
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)

    def _representative_token(self, tokens):
        """Pick the single token used to score a (possibly multi-token) ending.

        An ending that tokenizes to nothing maps to the tokenizer's unknown token;
        otherwise the token with the most characters left after stripping ASCII
        punctuation wins (the first such token on ties, as with ``max``).
        """
        if not tokens:
            return self.tokenizer.unk_token
        return max(tokens, key=lambda tok: len(tok.strip(string.punctuation)))

    def convert_examples_to_features(self, examples):
        """Tokenize questions and map every candidate ending to one vocabulary id.

        Args:
            examples: sequence of examples exposing ``question``, ``endings`` and
                ``label`` attributes.

        Returns:
            A list with one ``InputFeatures`` per example, holding the tokenizer
            outputs (padded/truncated to ``hparams['max_seq_length']``), the
            candidate token ids and the label.
        """
        tokenizer = self.tokenizer
        batch_encoding = tokenizer(
            [example.question for example in examples],
            max_length=self.hparams['max_seq_length'],
            padding='max_length',
            truncation=True,
        )
        features = []
        for i, example in enumerate(examples):
            inputs = {k: batch_encoding[k][i] for k in batch_encoding}
            token_candidates = [
                self._representative_token(tokenizer.tokenize(ending))
                for ending in example.endings
            ]
            candidate_ids = tokenizer.convert_tokens_to_ids(token_candidates)
            features.append(
                InputFeatures(**inputs, candidates=candidate_ids, label=example.label)
            )
        return features

    def test_dataloader(self):
        """Build the test DataLoader, reusing the on-disk feature cache when allowed.

        Each batch is ``(input_ids, attention_mask, token_type_ids, labels,
        candidates, answers)``; ``answers`` carries the same values as ``labels``
        and is kept so the batch layout stays unchanged.
        """
        mode = 'test'
        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file) and not self.hparams['overwrite_cache']:
            # NOTE(review): the cache holds pickled feature objects; newer torch
            # releases default torch.load to weights_only=True, which may refuse
            # them — confirm against the torch version in use.
            features = torch.load(cached_features_file)
        else:
            features = self.load_features(mode)
            torch.save(features, cached_features_file)

        def as_long(values):
            return torch.tensor(values, dtype=torch.long)

        all_input_ids = as_long([f.input_ids for f in features])
        all_attention_mask = as_long([f.attention_mask for f in features])
        # token_type_ids can be missing/empty (presumably for tokenizers that do
        # not emit them); substitute an all-zero row so this tensor has the same
        # (N, seq_len) shape as input_ids instead of collapsing to shape (N,).
        all_token_type_ids = as_long([
            f.token_type_ids or [0] * len(f.input_ids) for f in features
        ])
        all_labels = as_long([f.label for f in features])
        all_cands = as_long([f.candidates for f in features])
        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                          all_labels, all_cands, all_labels),
            batch_size=self.hparams['eval_batch_size'],
        )

    def test_step(self, batch, batch_idx):
        """Score one batch and count correct/incorrect candidate predictions.

        Returns:
            dict with integer ``right`` and ``wrong`` counts for the batch.
        """
        input_ids, attention_mask, token_type_ids, answers, candidates = batch[:5]
        # Element 0 of the model output is indexed as (batch, seq_len, vocab)
        # logits below — assumes a masked-LM head, as the indexing requires.
        logits = self(input_ids=input_ids, token_type_ids=token_type_ids,
                      attention_mask=attention_mask)[0]

        is_mask = input_ids == self.mask_id
        has_mask = is_mask.any(dim=1)
        if not bool(has_mask.all()):
            logger.warning(
                '%d example(s) in test batch %d contain no mask token; '
                'scoring them at position 0', int((~has_mask).sum()), batch_idx)
        # argmax returns the first maximal index, i.e. the first mask position
        # (or 0 for rows without any mask token).
        mask_positions = is_mask.long().argmax(dim=1)

        rows = torch.arange(logits.shape[0], device=logits.device)
        mask_logits = logits[rows, mask_positions]             # (batch, vocab)
        candidate_logits = mask_logits.gather(1, candidates)   # (batch, n_candidates)
        predicted = candidate_logits.argmax(dim=1)

        right = int((predicted == answers).sum().item())
        wrong = predicted.numel() - right
        return {"right": right, "wrong": wrong}

    def test_epoch_end(self, outputs):
        """Sum the per-batch counts and add them to the shared results file.

        The read-modify-write of ``test_results_fpath`` is done under a FileLock so
        concurrent writers do not overwrite each other's counts.

        Returns:
            The accumulated ``{'right': ..., 'wrong': ...}`` written to the file.
        """
        merged = {
            'right': sum(output['right'] for output in outputs),
            'wrong': sum(output['wrong'] for output in outputs),
        }
        with FileLock(self.test_results_fpath + '.lock'):
            data = merged
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    previous = pickle.load(fp)
                data = {key: previous[key] + merged[key] for key in ('right', 'wrong')}
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)
        return data

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """This module adds no extra CLI arguments; returns ``parser`` unchanged."""
        return parser

    def run_module(self):
        """Evaluate on the test split and write the accuracy to ``output_dir``.

        Writes ``{"test_acc": accuracy}`` as JSON to
        ``<output_dir>/test_results.txt``; accuracy is 0.0 (with a warning) when
        no examples were scored.
        """
        self.eval()
        self.freeze()
        trainer = create_trainer(self, self.hparams)
        # torch.no_grad() only disables autograd when used as a context manager.
        with torch.no_grad():
            trainer.test(self)
        with open(self.test_results_fpath, 'rb') as fp:
            preds = pickle.load(fp)
        correct, wrong = preds['right'], preds['wrong']
        total = correct + wrong
        if total:
            accuracy = correct / total
        else:
            logger.warning('No test examples were scored; reporting accuracy 0.0')
            accuracy = 0.0
        with open(os.path.join(self.hparams['output_dir'], 'test_results.txt'), 'w') as fp:
            json.dump({'test_acc': accuracy}, fp)