File size: 3,442 Bytes
d08dd00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
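"""Cross-lingual sentence retrieval evaluation.

Defines ``XSentRetrieval``, which embeds the 'en' and 'in' test feature sets
with mean-pooled encoder outputs, aligns them by label, and scores the pair
with precision@10 (see ``precision_at_10``).
"""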
import logging
import json
import os
import pickle
import scipy.spatial as sp
from filelock import FileLock

import numpy as np
import torch

from .base import BaseModule, create_trainer


logger = logging.getLogger(__name__)


class XSentRetrieval(BaseModule):
    """Cross-lingual sentence retrieval evaluated with precision@10.

    ``run_module`` encodes two test sets (``load_features('en')`` and
    ``load_features('in')``) into mean-pooled sentence vectors. It re-orders
    both sets by their label column and writes ``{"test_acc": ...}`` as JSON
    to ``<output_dir>/test_results.txt``.

    Per-run vectors are exchanged through a pickle file at
    ``test_results_fpath``. ``test_epoch_end`` appends to that file under a
    ``FileLock``. Presumably the lock exists so that several processes (e.g.
    distributed workers) can contribute safely; confirm against the trainer
    configuration.
    """

    mode = 'base'
    output_mode = 'classification'
    example_type = 'text'

    def __init__(self, hparams):
        # Relative path: resolved against the current working directory.
        self.test_results_fpath = 'test_results'
        # Discard leftovers from an earlier run so they are not merged into
        # this run's results by test_epoch_end.
        try:
            os.remove(self.test_results_fpath)
        except FileNotFoundError:
            pass

        super().__init__(hparams)

    def forward(self, **inputs):
        """Return the mean over dimension 1 of the model's first output.

        ``inputs`` are forwarded unchanged to ``self.model``. The first output
        is assumed to be last-layer hidden states of shape (batch, seq, hidden);
        TODO confirm for the configured model.
        """
        outputs = self.model(**inputs)
        last_hidden = outputs[0]
        # NOTE(review): this is an unmasked mean, so padding positions
        # (attention_mask == 0) also contribute to the sentence vector.
        # It is kept as-is to preserve previously reported scores; confirm
        # whether this is intended.
        return torch.mean(last_hidden, 1)

    def _make_test_loader(self, lang):
        """Build an evaluation loader over the cached features for ``lang``."""
        test_features = self.load_features(lang)
        return self.make_loader(test_features, self.hparams['eval_batch_size'])

    def test_dataloader_en(self):
        """Return the evaluation loader for the 'en' feature set."""
        return self._make_test_loader('en')

    def test_dataloader_in(self):
        """Return the evaluation loader for the 'in' feature set."""
        return self._make_test_loader('in')

    def test_step(self, batch, batch_idx):
        """Encode one batch and prepend each row's label as column 0.

        ``batch`` is indexed as (input_ids, attention_mask, token_type_ids,
        labels). The label column is later used to restore a common order
        across both test sets before they are compared.
        """
        inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
                  'attention_mask': batch[1]}
        labels = batch[3].detach().cpu().numpy()
        sentvecs = self(**inputs).detach().cpu().numpy()
        sentvecs = np.hstack([labels[:, None], sentvecs])

        return {'sentvecs': sentvecs}

    def test_epoch_end(self, outputs):
        """Stack all batch outputs and append them to the shared results file.

        The read-modify-write happens under a ``FileLock``. If the file
        already exists, the new rows are stacked below its contents.
        """
        all_sentvecs = np.vstack([x['sentvecs'] for x in outputs])

        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                # The file is produced only by this method (see below).
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = np.vstack([data, all_sentvecs])
            else:
                data = all_sentvecs
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)

        return {'sentvecs': all_sentvecs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """No task-specific arguments; return ``parser`` unchanged."""
        return parser

    def _pop_test_results(self):
        """Load the vectors written by ``test_epoch_end``, then delete the file."""
        with open(self.test_results_fpath, 'rb') as fp:
            sentvecs = pickle.load(fp)
        os.remove(self.test_results_fpath)
        return sentvecs

    @staticmethod
    def _sort_by_label(sentvecs):
        """Order rows by the label in column 0 and drop that column."""
        return sentvecs[sentvecs[:, 0].argsort()][:, 1:]

    def run_module(self):
        """Encode both test sets, score retrieval, and write the metric file."""
        self.eval()
        self.freeze()

        trainer = create_trainer(self, self.hparams)

        trainer.test(self, self.test_dataloader_en())
        sentvecs1 = self._pop_test_results()

        trainer.test(self, self.test_dataloader_in())
        sentvecs2 = self._pop_test_results()

        # Batches (and, with several workers, processes) may arrive in any
        # order. Sorting on the label column aligns row i of both sets.
        sentvecs1 = self._sort_by_label(sentvecs1)
        sentvecs2 = self._sort_by_label(sentvecs2)

        # Compute the metric before opening the output file, so a failure
        # does not leave an empty/truncated results file behind.
        metrics = {'test_acc': float(precision_at_10(sentvecs1, sentvecs2))}
        result_path = os.path.join(self.hparams['output_dir'], 'test_results.txt')
        with open(result_path, 'w') as fp:
            json.dump(metrics, fp)


def precision_at_10(sentvecs1, sentvecs2, k=10):
    """Return the fraction of queries whose aligned candidate ranks in the top ``k``.

    Row ``i`` of ``sentvecs1`` is a query whose correct match is row ``i`` of
    ``sentvecs2``. Each set is mean-centered on its own column means, and
    candidates are ranked by cosine distance. A query counts as a hit when
    its aligned row is among its ``k`` nearest candidates.

    Args:
        sentvecs1: 2-D array of query vectors, one per row.
        sentvecs2: 2-D array of candidate vectors with the same number of
            columns; row ``i`` is the gold match for query ``i``.
        k: Number of top-ranked candidates to accept. The default of 10 gives
            the precision@10 this function is named after.

    Returns:
        The hit rate (a NumPy float in ``[0, 1]``).

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    n = sentvecs1.shape[0]

    # Mean centering: remove each set's shared offset before comparing.
    sentvecs1 = sentvecs1 - np.mean(sentvecs1, axis=0)
    sentvecs2 = sentvecs2 - np.mean(sentvecs2, axis=0)

    # 'cosine' in cdist is a *distance* (1 - cosine similarity), so an
    # ascending sort puts the most similar candidates first.
    dist = sp.distance.cdist(sentvecs1, sentvecs2, 'cosine')
    top_k = dist.argsort(axis=1)[:, :k]
    gold = np.arange(n)
    hits = np.any(top_k == gold[:, None], axis=1)
    return hits.mean()