|
'''this eval code is borrowed from E5''' |
|
import os |
|
import json |
|
import tqdm |
|
import numpy as np |
|
import torch |
|
import argparse |
|
|
|
from datasets import Dataset |
|
from typing import List, Dict |
|
from functools import partial |
|
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding |
|
from transformers.modeling_outputs import BaseModelOutput |
|
from torch.utils.data import DataLoader |
|
from mteb import MTEB, AbsTaskRetrieval, DRESModel |
|
|
|
from utils import pool, logger, move_to_cuda |
|
|
|
# ---- command-line arguments and run setup (module-level side effects) ----
parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='bert-base-uncased',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--pool-type', default='avg', help='pool type')
# type=int so a CLI-supplied value matches the int default; the downstream
# int(args.max_length) cast stays harmless either way.
parser.add_argument('--max-length', default=512, type=int, help='max length')

args = parser.parse_args()
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
assert args.pool_type in ['cls', 'avg'], 'pool_type should be cls or avg'
assert args.output_dir, 'output_dir should be set'
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
|
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize the 'contents' column of a dataset batch.

    Truncates/pads to the global ``args.max_length`` and drops
    token type ids, which the encoder does not consume here.
    """
    encoded: BatchEncoding = tokenizer(
        examples['contents'],
        max_length=int(args.max_length),
        padding=True,
        return_token_type_ids=False,
        truncation=True,
    )
    return encoded
|
|
|
|
|
class RetrievalModel(DRESModel):
    """DRES-compatible wrapper that embeds queries/passages with a HF encoder.

    Queries get the instruction prefix '查询: ' and corpus passages get
    '结果: ' before encoding; embeddings are pooled according to
    ``args.pool_type`` ('cls' or 'avg').
    """

    def __init__(self, **kwargs):
        # NOTE(review): DRESModel.__init__ is intentionally not called; mteb
        # only requires the encode_queries/encode_corpus interface here —
        # confirm against the installed mteb version.
        self.encoder = AutoModel.from_pretrained(args.model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count > 1:
            # Data-parallel replication across all visible GPUs.
            self.encoder = torch.nn.DataParallel(self.encoder)

        self.encoder.cuda()
        self.encoder.eval()

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """Embed search queries (adds the '查询: ' instruction prefix)."""
        input_texts = ['查询: {}'.format(q) for q in queries]
        return self._do_encode(input_texts)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        """Embed corpus documents as '<title> <text>' with the '结果: ' prefix.

        A missing 'title' key is tolerated (empty string); 'text' is required.
        """
        input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        input_texts = ['结果: {}'.format(t) for t in input_texts]
        return self._do_encode(input_texts)

    @torch.no_grad()
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        """Tokenize, batch and encode texts.

        Returns a float array of shape (len(input_texts), hidden_dim),
        in the same order as ``input_texts`` (shuffle/drop_last disabled).
        """
        dataset: Dataset = Dataset.from_dict({'contents': input_texts})
        dataset.set_transform(partial(_transform_func, self.tokenizer))

        # pad_to_multiple_of=8 keeps tensor shapes friendly for mixed precision.
        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        # Scale the batch across GPUs; max(1, ...) avoids batch_size == 0
        # (an invalid DataLoader batch size) on a machine without CUDA devices.
        batch_size = 128 * max(1, self.gpu_count)
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            batch_dict = move_to_cuda(batch_dict)

            # Mixed-precision forward pass for speed; pooling stays inside
            # the autocast context as in the original implementation.
            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
            encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)
|
|
|
# Chinese retrieval tasks from the MTEB benchmark evaluated by this script.
TASKS = ["T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval", "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval"]


def main():
    """Run every retrieval task in TASKS through MTEB with RetrievalModel."""
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and an incompatible model must always abort the run.
    if not AbsTaskRetrieval.is_dres_compatible(RetrievalModel):
        raise RuntimeError('RetrievalModel is not DRES-compatible with mteb')
    model = RetrievalModel()

    task_names = [t.description["name"] for t in MTEB(tasks=TASKS).tasks]
    logger.info('Tasks: {}'.format(task_names))

    for task in task_names:
        logger.info('Processing task: {}'.format(task))
        evaluation = MTEB(tasks=[task])
        # overwrite_results=False: tasks whose result file already exists are
        # skipped, so an interrupted run can be resumed.
        evaluation.run(model, output_folder=args.output_dir, overwrite_results=False)


if __name__ == '__main__':
    main()
|
|