"""
Based on https://github.com/huggingface/transformers/issues/80

"""

import json
import logging
import os
import pickle
import string

import numpy as np
import torch
from filelock import FileLock
from torch.utils.data import DataLoader, TensorDataset

from .base import BaseModule, create_trainer
from ..data.examples import InputFeatures


logger = logging.getLogger(__name__)


class MaskedLM(BaseModule):

    mode = 'language-modeling'
    output_mode = 'classification'
    example_type = 'multiple-choice'

    def __init__(self, hparams):
        super().__init__(hparams)

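        # Id of the mask token, used to locate the masked position in each input
        # sequence (assumes the tokenizer uses a literal '[MASK]' token, e.g. BERT).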
        self.mask_id = self.tokenizer.convert_tokens_to_ids('[MASK]')
        self.test_results_fpath = 'test_results'
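        # Delete any stale results file so counts from a previous run don't accumulate.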
        if os.path.exists(self.test_results_fpath):
            os.remove(self.test_results_fpath)

    def convert_examples_to_features(self, examples):
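        # Encode each question, then reduce every answer candidate to a single
        # representative token (the sub-token with the most non-punctuation
        # characters) so it can be scored against the MLM vocabulary.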

        batch_encoding = self.tokenizer(
            [example.question for example in examples],
            max_length=self.hparams['max_seq_length'],
            padding='max_length',
            truncation=True,
        )

        features = []
        for i in range(len(examples)):
            inputs = {k: batch_encoding[k][i] for k in batch_encoding}
            candidates = examples[i].endings
            tokens = [self.tokenizer.tokenize(cand) for cand in candidates]
            token_candidates = []

            for toks in tokens:
                if len(toks) == 0:
                    token_candidates.append(self.tokenizer.unk_token)
                else:
                    token_candidates.append(max(toks, key=lambda t: len(t.strip(string.punctuation))))
            candidate_ids = self.tokenizer.convert_tokens_to_ids(token_candidates)

            feature = InputFeatures(**inputs, candidates=candidate_ids, label=examples[i].label)
            features.append(feature)

        return features

    def test_dataloader(self):
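        # Build the test DataLoader, loading features from the on-disk cache when
        # present (unless overwrite_cache is set) and writing the cache otherwise.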
        mode = 'test'
        cached_features_file = self._feature_file(mode)
        if os.path.exists(cached_features_file) and not self.hparams['overwrite_cache']:
            features = torch.load(cached_features_file)
        else:
            features = self.load_features(mode)
            torch.save(features, cached_features_file)

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor(
            [f.token_type_ids if f.token_type_ids is not None else [0] * len(f.input_ids) for f in features],
            dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        all_cands = torch.tensor([f.candidates for f in features], dtype=torch.long)

        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels, all_cands),
            batch_size=self.hparams['eval_batch_size'],
        )

    def test_step(self, batch, batch_idx):
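        # Gather the MLM logits at the first [MASK] position of each sequence, score
        # each answer candidate by its logit there, and count how many predictions
        # match the gold label.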
        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
                  'attention_mask': batch[1]}

        answers = batch[3].detach().cpu().numpy()
        candidates = batch[4].detach().cpu().numpy()

        # get first mask location
        input_ids = batch[0].detach().cpu().numpy()
        mask_ids = (input_ids == self.mask_id).argmax(axis=1)
        mask_ids = torch.from_numpy(mask_ids)

        predictions = self(**inputs)[0]

        row_idx = torch.arange(predictions.shape[0], dtype=torch.int64)
        predictions = predictions[row_idx, mask_ids]
        predictions = predictions.detach().cpu().numpy()

        right, wrong = 0, 0

        for i, pred in enumerate(predictions):
            prob = pred[candidates[i]]
            pred_answer = int(np.argmax(prob))
            if answers[i] == pred_answer:
                right += 1
            else:
                wrong += 1

        return {"right": right, "wrong": wrong}

    def test_epoch_end(self, outputs):
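        # Sum the per-batch counts and merge them into the pickled results file under
        # a file lock, so runs with multiple processes accumulate into one total.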
        right = sum(output['right'] for output in outputs)
        wrong = sum(output['wrong'] for output in outputs)
        merged = {'right': right, 'wrong': wrong}

        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = {'right': data['right'] + merged['right'], 'wrong': data['wrong'] + merged['wrong']}
            else:
                data = merged
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)

        return data

    @staticmethod
    def add_model_specific_args(parser, root_dir):
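        # This module adds no extra command-line arguments; the parser is returned unchanged.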
        return parser

    def run_module(self):
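        # Evaluation only: switch to eval mode and freeze all parameters, run
        # trainer.test (which already disables gradient tracking), then write the
        # final accuracy to output_dir/test_results.txt.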
        self.eval()
        self.freeze()

        trainer = create_trainer(self, self.hparams)

        trainer.test(self)
        with open(self.test_results_fpath, 'rb') as fp:
            preds = pickle.load(fp)
        correct, wrong = preds['right'], preds['wrong']
        with open(os.path.join(self.hparams['output_dir'], 'test_results.txt'), 'w') as fp:
            json.dump({'test_acc': correct/(correct + wrong)}, fp)