sunit333's picture
Upload 63 files
d08dd00 verified
raw
history blame
3.44 kB
"""
"""
import logging
import json
import os
import pickle
import scipy.spatial as sp
from filelock import FileLock
import numpy as np
import torch
from .base import BaseModule, create_trainer
logger = logging.getLogger(__name__)
class XSentRetrieval(BaseModule):
mode = 'base'
output_mode = 'classification'
example_type = 'text'
def __init__(self, hparams):
self.test_results_fpath = 'test_results'
if os.path.exists(self.test_results_fpath):
os.remove(self.test_results_fpath)
super().__init__(hparams)
def forward(self, **inputs):
outputs = self.model(**inputs)
last_hidden = outputs[0]
mean_pooled = torch.mean(last_hidden, 1)
return mean_pooled
def test_dataloader_en(self):
test_features = self.load_features('en')
dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
return dataloader
def test_dataloader_in(self):
test_features = self.load_features('in')
dataloader = self.make_loader(test_features, self.hparams['eval_batch_size'])
return dataloader
def test_step(self, batch, batch_idx):
inputs = {'input_ids': batch[0], 'token_type_ids': batch[2],
'attention_mask': batch[1]}
labels = batch[3].detach().cpu().numpy()
sentvecs = self(**inputs)
sentvecs = sentvecs.detach().cpu().numpy()
sentvecs = np.hstack([labels[:, None], sentvecs])
return {'sentvecs': sentvecs}
def test_epoch_end(self, outputs):
all_sentvecs = np.vstack([x['sentvecs'] for x in outputs])
with FileLock(self.test_results_fpath + '.lock'):
if os.path.exists(self.test_results_fpath):
with open(self.test_results_fpath, 'rb') as fp:
data = pickle.load(fp)
data = np.vstack([data, all_sentvecs])
else:
data = all_sentvecs
with open(self.test_results_fpath, 'wb') as fp:
pickle.dump(data, fp)
return {'sentvecs': all_sentvecs}
@staticmethod
def add_model_specific_args(parser, root_dir):
return parser
def run_module(self):
self.eval()
self.freeze()
trainer = create_trainer(self, self.hparams)
trainer.test(self, self.test_dataloader_en())
sentvecs1 = pickle.load(open(self.test_results_fpath, 'rb'))
os.remove(self.test_results_fpath)
trainer.test(self, self.test_dataloader_in())
sentvecs2 = pickle.load(open(self.test_results_fpath, 'rb'))
os.remove(self.test_results_fpath)
sentvecs1 = sentvecs1[sentvecs1[:, 0].argsort()][:, 1:]
sentvecs2 = sentvecs2[sentvecs2[:, 0].argsort()][:, 1:]
result_path = os.path.join(self.hparams['output_dir'], 'test_results.txt')
with open(result_path, 'w') as fp:
metrics = {'test_acc': precision_at_10(sentvecs1, sentvecs2)}
json.dump(metrics, fp)
def precision_at_10(sentvecs1, sentvecs2):
n = sentvecs1.shape[0]
# mean centering
sentvecs1 = sentvecs1 - np.mean(sentvecs1, axis=0)
sentvecs2 = sentvecs2 - np.mean(sentvecs2, axis=0)
sim = sp.distance.cdist(sentvecs1, sentvecs2, 'cosine')
actual = np.array(range(n))
preds = sim.argsort(axis=1)[:, :10]
matches = np.any(preds == actual[:, None], axis=1)
return matches.mean()