Jinkin committed on
Commit
680ab9f
1 Parent(s): bf14719

update eval scripts

Browse files
eval/cmteb_eval.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import logging
import os

from mteb import MTEB
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger("main")

# CMTEB task names grouped by category; results are written under <output-dir>/<category>/.
CLASSIFICATION_LIST = ["TNews", "IFlyTek", "MultilingualSentiment", "JDReview", "OnlineShopping", "Waimai"]
STS_LIST = ["ATEC", "BQ", "LCQMC", "PAWSX", "STSB", "AFQMC", "QBQTC"]
PAIRCLASSIFICATION_LIST = ["Ocnli", "Cmnli"]
RERANKING_LIST = ["T2Reranking", "MmarcoReranking", "CMedQAv1", "CMedQAv2"]
CLUSTERING_LIST = ["CLSClusteringS2S", "CLSClusteringP2P", "ThuNewsClusteringS2S", "ThuNewsClusteringP2P"]
TASK_LIST = [CLASSIFICATION_LIST, STS_LIST, PAIRCLASSIFICATION_LIST, RERANKING_LIST, CLUSTERING_LIST]
names = ['Classification', 'STS', 'Pairclassification', 'Reranking', 'Clustering']


def main() -> None:
    """Evaluate a SentenceTransformer checkpoint on every configured CMTEB task.

    The model and output root were previously hard-coded; they are now CLI
    flags whose defaults reproduce the original behavior exactly.
    """
    parser = argparse.ArgumentParser(description='CMTEB evaluation')
    parser.add_argument('--model-name-or-path', default='piccolo-base-zh',
                        help='SentenceTransformer model name or local path')
    parser.add_argument('--output-dir', default='results',
                        help='root directory for per-category result folders')
    args = parser.parse_args()

    model = SentenceTransformer(args.model_name_or_path)
    for name, task_list in zip(names, TASK_LIST):
        for task in task_list:
            logger.info(f"Running task: {task}")
            evaluation = MTEB(tasks=[task])
            # One MTEB instance per task so each run is isolated per category folder.
            evaluation.run(model, output_folder=os.path.join(args.output_dir, name))


if __name__ == '__main__':
    main()
eval/cmteb_eval.sh ADDED
@@ -0,0 +1 @@
 
 
1
#!/usr/bin/env bash
# Launcher for the CMTEB evaluation; fails fast on the first error.
set -e

# Forward any extra CLI flags (e.g. --model-name-or-path) to the Python script.
python -u cmteb_eval.py "$@"
eval/retrieval_eval.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
'''this eval code is borrowed from E5'''
import os
import json
import tqdm
import numpy as np
import torch
import argparse

from datasets import Dataset
from typing import List, Dict
from functools import partial
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
from transformers.modeling_outputs import BaseModelOutput
from torch.utils.data import DataLoader
from mteb import MTEB, AbsTaskRetrieval, DRESModel

from utils import pool, logger, move_to_cuda

parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='bert-base-uncased',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--pool-type', default='avg', help='pool type')
# type=int so CLI-supplied values arrive as int, matching the default's type.
parser.add_argument('--max-length', default=512, type=int, help='max length')

args = parser.parse_args()
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
# Validate with real exceptions: `assert` statements are stripped under `python -O`.
if args.pool_type not in ('cls', 'avg'):
    raise ValueError('pool_type should be cls or avg')
if not args.output_dir:
    raise ValueError('output_dir should be set')
os.makedirs(args.output_dir, exist_ok=True)
32
+
33
+
34
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize a batch of 'contents' strings for on-the-fly dataset transforms.

    Truncates to the CLI-configured max length and skips token-type ids,
    which the pooling code does not consume.
    """
    tokenize_kwargs = dict(
        max_length=int(args.max_length),
        padding=True,
        return_token_type_ids=False,
        truncation=True,
    )
    return tokenizer(examples['contents'], **tokenize_kwargs)
41
+
42
+
43
class RetrievalModel(DRESModel):
    """DRES-compatible wrapper around a HuggingFace encoder for MTEB retrieval tasks."""
    # Refer to the code of DRESModel for the methods to overwrite
    def __init__(self, **kwargs):
        # Model and tokenizer come from the module-level CLI args; kwargs are ignored.
        self.encoder = AutoModel.from_pretrained(args.model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count > 1:
            # Data-parallel replication across all visible GPUs.
            self.encoder = torch.nn.DataParallel(self.encoder)

        self.encoder.cuda()  # NOTE(review): assumes at least one CUDA device — crashes on CPU-only hosts
        self.encoder.eval()

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        # Prepend the Chinese instruction prefix '查询: ' ("query: ") before encoding.
        input_texts = ['查询: {}'.format(q) for q in queries]
        return self._do_encode(input_texts)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        # Join optional title with body text, then prepend '结果: ' ("result: ").
        input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        input_texts = ['结果: {}'.format(t) for t in input_texts]
        return self._do_encode(input_texts)

    @torch.no_grad()
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        """Tokenize, batch and encode texts; returns one embedding row per input text."""
        dataset: Dataset = Dataset.from_dict({'contents': input_texts})
        # Lazy tokenization: rows are tokenized only when a batch is materialized.
        dataset.set_transform(partial(_transform_func, self.tokenizer))

        # pad_to_multiple_of=8 keeps tensor shapes friendly for tensor-core kernels.
        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        # NOTE(review): batch_size is 0 when gpu_count == 0 — confirm GPU-only usage.
        batch_size = 128 * self.gpu_count
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,  # output order must match the input text order
            drop_last=False,
            num_workers=4,
            collate_fn=data_collator,
            pin_memory=True)

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            batch_dict = move_to_cuda(batch_dict)

            # Mixed-precision forward pass; pooled embeddings are moved back to CPU.
            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
                encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)
90
+
91
TASKS = ["T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval", "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval"]
def main():
    """Run every configured CMTEB retrieval task and write results to the output dir."""
    # DRES compatibility is required for MTEB's dense-retrieval evaluation path.
    assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel)
    model = RetrievalModel()

    # Resolve the canonical task names via MTEB before iterating.
    resolved_tasks = MTEB(tasks=TASKS).tasks
    task_names = [task_obj.description["name"] for task_obj in resolved_tasks]
    logger.info('Tasks: {}'.format(task_names))

    for task_name in task_names:
        logger.info('Processing task: {}'.format(task_name))
        # One evaluation per task; existing results are kept (no overwrite).
        single_task_eval = MTEB(tasks=[task_name])
        single_task_eval.run(model, output_folder=args.output_dir, overwrite_results=False)


if __name__ == '__main__':
    main()
eval/retrieval_eval.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash

# Trace every command and abort on the first failure.
set -x
set -e

# Repository root: one level above this script's directory.
DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
echo "working directory: ${DIR}"

# NOTE(review): the script runs retrieval_eval.py from the caller's cwd, not ${DIR} —
# confirm it is always invoked from the eval/ directory.
MODEL_NAME_OR_PATH="piccolo-base-zh"
OUTPUT_DIR='Retrieval'

mkdir -p "${OUTPUT_DIR}"

# Extra CLI flags ("$@") are forwarded and can override the defaults above.
python -u retrieval_eval.py \
    --model-name-or-path "${MODEL_NAME_OR_PATH}" \
    --pool-type avg \
    --output-dir "${OUTPUT_DIR}" "$@"