Spaces:
No application file
No application file
"""Sentence-retrieval evaluation: encode two test sets and score precision@10."""
import logging | |
import json | |
import os | |
import pickle | |
import scipy.spatial as sp | |
from filelock import FileLock | |
import numpy as np | |
import torch | |
from .base import BaseModule, create_trainer | |
logger = logging.getLogger(__name__) | |
class XSentRetrieval(BaseModule):
    """Sentence-retrieval evaluation module.

    Encodes the 'en' and 'in' test feature sets with ``self.model``,
    averages the first model output over dimension 1 to obtain one vector
    per example, and scores with ``precision_at_10`` how often the row
    with the same position (after sorting by label) is retrieved from the
    other set. See ``run_module`` for the full pipeline.
    """

    mode = 'base'
    output_mode = 'classification'
    example_type = 'text'

    def __init__(self, hparams):
        """Set up the scratch results path and initialise the base module.

        Args:
            hparams: Hyper-parameters, forwarded unchanged to the base class.
        """
        # Scratch pickle (relative to the current working directory) that
        # test_epoch_end appends to and run_module reads back.
        self.test_results_fpath = 'test_results'
        # Remove any leftover file so vectors from an earlier run are never
        # merged into this run's results.
        try:
            os.remove(self.test_results_fpath)
        except FileNotFoundError:
            pass
        super().__init__(hparams)

    def forward(self, **inputs):
        """Return one pooled vector per example.

        Args:
            **inputs: Keyword arguments passed straight to ``self.model``.

        Returns:
            The first element of the model output averaged over dimension 1.
        """
        outputs = self.model(**inputs)
        last_hidden = outputs[0]
        # NOTE(review): presumably last_hidden is (batch, seq_len, hidden);
        # if so, this mean also averages padding positions because the
        # attention mask is not applied here -- confirm this is intended.
        return torch.mean(last_hidden, 1)

    def _test_dataloader(self, lang):
        """Build the evaluation dataloader for the features of ``lang``."""
        features = self.load_features(lang)
        return self.make_loader(features, self.hparams['eval_batch_size'])

    def test_dataloader_en(self):
        """Return the test dataloader built from the 'en' features."""
        return self._test_dataloader('en')

    def test_dataloader_in(self):
        """Return the test dataloader built from the 'in' features."""
        return self._test_dataloader('in')

    def test_step(self, batch, batch_idx):
        """Encode one batch and tag every vector with its label.

        Args:
            batch: Indexable batch; positions 0, 1, 2 and 3 are used as
                input ids, attention mask, token type ids and labels.
            batch_idx: Index of the batch (unused).

        Returns:
            dict with key 'sentvecs': a NumPy array whose first column is
            the label and whose remaining columns are the pooled vector.
        """
        inputs = {
            'input_ids': batch[0],
            'token_type_ids': batch[2],
            'attention_mask': batch[1],
        }
        labels = batch[3].detach().cpu().numpy()
        sentvecs = self(**inputs).detach().cpu().numpy()
        # Prepend the label as column 0 so rows can be re-ordered later.
        return {'sentvecs': np.hstack([labels[:, None], sentvecs])}

    def test_epoch_end(self, outputs):
        """Stack all batch outputs and append them to the results file.

        Args:
            outputs: Iterable of the dicts returned by ``test_step``.

        Returns:
            dict with key 'sentvecs' holding the stacked rows of this epoch.
        """
        all_sentvecs = np.vstack([x['sentvecs'] for x in outputs])
        # The lock serialises the read-merge-write cycle on the results
        # file (presumably because several processes may run this hook).
        with FileLock(self.test_results_fpath + '.lock'):
            if os.path.exists(self.test_results_fpath):
                # NOTE: pickle executes arbitrary code on load; this is
                # acceptable only because the file is written by this
                # module itself. Never point this path at untrusted data.
                with open(self.test_results_fpath, 'rb') as fp:
                    data = pickle.load(fp)
                data = np.vstack([data, all_sentvecs])
            else:
                data = all_sentvecs
            with open(self.test_results_fpath, 'wb') as fp:
                pickle.dump(data, fp)
        return {'sentvecs': all_sentvecs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Return ``parser`` unchanged; this module adds no arguments.

        Declared as a static method so it can be called on the class or
        on an instance (the signature takes no ``self``).

        Args:
            parser: Argument parser to extend.
            root_dir: Unused.
        """
        return parser

    def _collect_sentvecs(self, trainer, dataloader):
        """Run one test pass and return its vectors ordered by label.

        Reads back the pickle written by ``test_epoch_end``, deletes it so
        the next pass starts from an empty file, sorts the rows by the
        label column and drops that column.
        """
        trainer.test(self, dataloader)
        # Context manager guarantees the handle is closed before removal.
        with open(self.test_results_fpath, 'rb') as fp:
            rows = pickle.load(fp)
        os.remove(self.test_results_fpath)
        return rows[rows[:, 0].argsort()][:, 1:]

    def run_module(self):
        """Encode both test sets and write precision@10 as JSON.

        The metric is written under the key 'test_acc' to
        ``<hparams['output_dir']>/test_results.txt``.
        """
        self.eval()
        self.freeze()
        trainer = create_trainer(self, self.hparams)

        sentvecs_en = self._collect_sentvecs(trainer, self.test_dataloader_en())
        sentvecs_in = self._collect_sentvecs(trainer, self.test_dataloader_in())

        # Compute the metric before opening the output file so a failure
        # here does not leave an empty results file behind.
        metrics = {'test_acc': float(precision_at_10(sentvecs_en, sentvecs_in))}

        result_path = os.path.join(self.hparams['output_dir'], 'test_results.txt')
        with open(result_path, 'w') as fp:
            json.dump(metrics, fp)
def precision_at_10(sentvecs1, sentvecs2):
    """Return the fraction of rows whose same-index match is in the top 10.

    Both matrices are mean-centred column-wise, then the cosine distance
    between every row of ``sentvecs1`` and every row of ``sentvecs2`` is
    computed. Row ``i`` counts as a hit when column index ``i`` is among
    the 10 smallest distances of that row.

    Args:
        sentvecs1: 2-D array of query vectors, one per row.
        sentvecs2: 2-D array of candidate vectors, one per row.

    Returns:
        Mean of the per-row hit indicators.
    """
    num_queries = sentvecs1.shape[0]

    # Mean centering: subtract each set's own per-dimension average.
    centred_queries = sentvecs1 - sentvecs1.mean(axis=0)
    centred_candidates = sentvecs2 - sentvecs2.mean(axis=0)

    # Cosine *distance*: smaller values mean more similar vectors.
    distances = sp.distance.cdist(centred_queries, centred_candidates, 'cosine')
    nearest = np.argsort(distances, axis=1)[:, :10]

    # Row i is expected to retrieve candidate i.
    expected = np.arange(num_queries).reshape(-1, 1)
    hits = (nearest == expected).any(axis=1)
    return hits.mean()