ianma2024 commited on
Commit
dd0b4f3
1 Parent(s): 2d4729d

upload codes and model weights

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 data-comment
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ListConRanker_ckpt/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "classifier_dropout": null,
7
+ "directionality": "bidi",
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "id2label": {
13
+ "0": "LABEL_0"
14
+ },
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 4096,
17
+ "label2id": {
18
+ "LABEL_0": 0
19
+ },
20
+ "layer_norm_eps": 1e-12,
21
+ "max_position_embeddings": 512,
22
+ "model_type": "bert",
23
+ "num_attention_heads": 16,
24
+ "num_hidden_layers": 24,
25
+ "output_hidden_states": true,
26
+ "pad_token_id": 0,
27
+ "pooler_fc_size": 768,
28
+ "pooler_num_attention_heads": 12,
29
+ "pooler_num_fc_layers": 3,
30
+ "pooler_size_per_head": 128,
31
+ "pooler_type": "first_token_transform",
32
+ "position_embedding_type": "absolute",
33
+ "torch_dtype": "bfloat16",
34
+ "transformers_version": "4.45.2",
35
+ "type_vocab_size": 2,
36
+ "use_cache": true,
37
+ "vocab_size": 21128
38
+ }
ListConRanker_ckpt/linear_in_embedding.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ae1bb10d5c23c3bdbe50dfee2f37bb243d1606cc1a41e02a8ffb7bf61b71033
3
+ size 7348826
ListConRanker_ckpt/list_transformer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cabbabfb04bc1feb6fa859a074f97397e88b0dc6bf14b0ef9ad3a0ddfac1cef5
3
+ size 293894397
ListConRanker_ckpt/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15c1c9ebc02bd5255758d0ba5498b3c93fd4cc8dd25845fd6a2cac8b2d12cefc
3
+ size 1302134568
ListConRanker_ckpt/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
ListConRanker_ckpt/tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "max_length": 512,
50
+ "model_max_length": 1000000000000000019884624838656,
51
+ "never_split": null,
52
+ "pad_to_multiple_of": null,
53
+ "pad_token": "[PAD]",
54
+ "pad_token_type_id": 0,
55
+ "padding_side": "right",
56
+ "sep_token": "[SEP]",
57
+ "stride": 0,
58
+ "strip_accents": null,
59
+ "tokenize_chinese_chars": true,
60
+ "tokenizer_class": "BertTokenizer",
61
+ "truncation_side": "right",
62
+ "truncation_strategy": "longest_first",
63
+ "unk_token": "[UNK]"
64
+ }
ListConRanker_ckpt/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,3 +1,164 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ model-index:
3
+ - name: ListConRanker
4
+ results:
5
+ - dataset:
6
+ config: default
7
+ name: MTEB CMedQAv1-reranking (default)
8
+ revision: null
9
+ split: test
10
+ type: C-MTEB/CMedQAv1-reranking
11
+ metrics:
12
+ - type: map
13
+ value: 90.55366308098787
14
+ - type: mrr_1
15
+ value: 87.8
16
+ - type: mrr_10
17
+ value: 92.45134920634919
18
+ - type: mrr_5
19
+ value: 92.325
20
+ - type: main_score
21
+ value: 90.55366308098787
22
+ task:
23
+ type: Reranking
24
+ - dataset:
25
+ config: default
26
+ name: MTEB CMedQAv2-reranking (default)
27
+ revision: null
28
+ split: test
29
+ type: C-MTEB/CMedQAv2-reranking
30
+ metrics:
31
+ - type: map
32
+ value: 89.38076135722042
33
+ - type: mrr_1
34
+ value: 85.9
35
+ - type: mrr_10
36
+ value: 91.28769841269842
37
+ - type: mrr_5
38
+ value: 91.08999999999999
39
+ - type: main_score
40
+ value: 89.38076135722042
41
+ task:
42
+ type: Reranking
43
+ - dataset:
44
+ config: default
45
+ name: MTEB MMarcoReranking (default)
46
+ revision: null
47
+ split: dev
48
+ type: C-MTEB/Mmarco-reranking
49
+ metrics:
50
+ - type: map
51
+ value: 43.881461866703894
52
+ - type: mrr_1
53
+ value: 32.0
54
+ - type: mrr_10
55
+ value: 44.700793650793656
56
+ - type: mrr_5
57
+ value: 43.61666666666667
58
+ - type: main_score
59
+ value: 43.881461866703894
60
+ task:
61
+ type: Reranking
62
+ - dataset:
63
+ config: default
64
+ name: MTEB T2Reranking (default)
65
+ revision: null
66
+ split: dev
67
+ type: C-MTEB/T2Reranking
68
+ metrics:
69
+ - type: map
70
+ value: 69.16513825032682
71
+ - type: mrr_1
72
+ value: 67.41706161137441
73
+ - type: mrr_10
74
+ value: 80.0946053776961
75
+ - type: mrr_5
76
+ value: 79.71676822387724
77
+ - type: main_score
78
+ value: 69.16513825032682
79
+ task:
80
+ type: Reranking
81
+ tags:
82
+ - mteb
83
+ ---
84
+
85
+
86
+ # ListConRanker
87
+ ## Model
88
+
89
+ - We propose a **List**wise-encoded **Con**trastive text re**Ranker** (**ListConRanker**), which includes a ListTransformer module for listwise encoding. The ListTransformer can facilitate global contrastive information learning between passage features, including the clustering of similar passages, the clustering between dissimilar passages, and the distinction between similar and dissimilar passages. Besides, we propose ListAttention to help ListTransformer maintain the features of the query while learning global comparative information.
90
+ - The training loss function is Circle Loss[1]. Compared with cross-entropy loss and ranking loss, it can solve the problems of low data efficiency and unsmooth gradient change.
91
+
92
+ ## Data
93
+ The training data consists of approximately 2.6 million queries, each corresponding to multiple passages. The data comes from the training sets of several datasets, including cMedQA1.0, cMedQA2.0, MMarcoReranking, T2Reranking, huatuo, MARC, XL-sum, CSL and so on.
94
+
95
+ ## Training
96
+ We trained the model in two stages. In the first stage, we freeze the parameters of embedding model and only train the ListTransformer for 4 epochs with a batch size of 1024. In the second stage, we do not freeze any parameter and train for another 2 epochs with a batch size of 256.
97
+
98
+ ## Inference
99
+ Due to the limited memory of GPUs, we input about 20 passages at a time for each query during training. However, during actual use, there may be situations where far more than 20 passages are input at the same time (e.g., MMarcoReranking).
100
+
101
+ To reduce the discrepancy between training and inference, we propose iterative inference. The iterative inference feeds the passages into the ListConRanker multiple times, and each time it only decides the ranking of the passage at the end of the list.
102
+
103
+ ## Performance
104
+ | Model | cMedQA1.0 | cMedQA2.0 | MMarcoReranking | T2Reranking | Avg. |
105
+ | :--- | :---: | :---: | :---: | :---: | :---: |
106
+ | LdIR-Qwen2-reranker-1.5B | 86.50 | 87.11 | 39.35 | 68.84 | 70.45 |
107
+ | zpoint-large-embedding-zh | 91.11 | 90.07 | 38.87 | 69.29 | 72.34 |
108
+ | xiaobu-embedding-v2 | 90.96 | 90.41 | 39.91 | 69.03 | 72.58 |
109
+ | Conan-embedding-v1 | 91.39 | 89.72 | 41.58 | 68.36 | 72.76 |
110
+ | ListConRanker | 90.55 | 89.38 | 43.88 | 69.17 | **73.25** |
111
+ | - w/o Iterative Inference | 90.20 | 89.98 | 37.52 | 69.17 | 71.72 |
112
+
113
+ ## How to use
114
+ ```python
115
+ from modules.listconranker import ListConRanker
116
+
117
+ reranker = ListConRanker('./ListConRanker_ckpt', use_fp16=True, list_transformer_layer=2)
118
+
119
+ # [query, passage_1, passage_2, ..., passage_n]
120
+ batch = [
121
+ [
122
+ '皮蛋是寒性的食物吗', # query
123
+ '营养医师介绍皮蛋是属于凉性的食物,中医认为皮蛋可治眼疼、牙疼、高血压、耳鸣眩晕等疾病。体虚者要少吃。', # passage_1
124
+ '皮蛋这种食品是在中国地域才常见的传统食品,它的生长汗青也是非常的悠长。', # passage_2
125
+ '喜欢皮蛋的人会觉得皮蛋是最美味的食物,不喜欢皮蛋的人则觉得皮蛋是黑暗料理,尤其很多外国朋友都不理解我们吃皮蛋的习惯' # passage_3
126
+ ],
127
+ [
128
+ '月有阴晴圆缺的意义', # query
129
+ '形容的是月所有的状态,晴朗明媚,阴沉混沌,有月圆时,但多数时总是有缺陷。', # passage_1
130
+ '人有悲欢离合,月有阴晴圆缺这句话意思是人有悲欢离合的变迁,月有阴晴圆缺的转换。', # passage_2
131
+ '既然是诗歌,又哪里会有真正含义呢? 大概可以说:人生有太多坎坷,苦难,从容坦荡面对就好。', # passage_3
132
+ '一零七六年苏轼贬官密州,时年四十一岁的他政治上很不得志,时值中秋佳节,非常想念自己的弟弟子由内心颇感忧郁,情绪低沉,有感而发写了这首词。' # passage_4
133
+ ]
134
+ ]
135
+
136
+ # for conventional inference, please manage the batch size by yourself
137
+ scores = reranker.compute_score(batch)
138
+ print(scores)
139
+ # [[0.5126953125, 0.331298828125, 0.3642578125], [0.63671875, 0.71630859375, 0.42822265625, 0.35302734375]]
140
+
141
+ # for iterative inference, only a batch size of 1 is supported
142
+ # the scores do not indicate similarity but are intended only for ranking
143
+ scores = reranker.iterative_inference(batch[0])
144
+ print(scores)
145
+ # [0.5126953125, 0.331298828125, 0.3642578125]
146
+ ```
147
+
148
+ To reproduce the results with iterative inference, please run:
149
+ ```bash
150
+ python3 eval_listconranker_iterative_inference.py
151
+ ```
152
+
153
+ To reproduce the results without iterative inference, please run:
154
+ ```bash
155
+ python3 eval_listconranker.py
156
+ ```
157
+
158
+ ## Reference
159
+ 1. https://arxiv.org/abs/2002.10857
160
+ 2. https://github.com/FlagOpen/FlagEmbedding
161
+ 3. https://arxiv.org/abs/2408.15710
162
+
163
+ ## License
164
+ This work is licensed under a [MIT License](https://opensource.org/license/MIT) and the weight of models is licensed under a [Creative Commons Attribution-NonCommercial 4.0 International License](https://creativecommons.org/licenses/by-nc/4.0/).
eval_listconranker.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import argparse
21
+ from modules.Reranking import *
22
+ from mteb import MTEB
23
+ from modules.listconranker import ListConRanker
24
+
25
+
26
def get_args():
    """Parse command-line options for the evaluation script.

    Returns:
        argparse.Namespace with a single attribute, ``model_name_or_path``,
        pointing at the ListConRanker checkpoint directory.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model_name_or_path',
        type=str,
        default="./ListConRanker_ckpt",
    )
    return arg_parser.parse_args()
30
+
31
+
32
if __name__ == '__main__':
    args = get_args()

    model = ListConRanker(args.model_name_or_path, use_fp16=True, list_transformer_layer=2)

    # Results are written to reranker_results/<dir_name>/<save_name>.
    # For a training-checkpoint path (".../run/checkpoint-XXX") the run
    # directory becomes dir_name and "run_checkpoint-XXX" the leaf; otherwise
    # the parent directory and the model directory name are used.
    # Fix: the original assigned dir_name once before this branch, but both
    # branches overwrite it, so that first assignment was dead code.
    if 'checkpoint-' in args.model_name_or_path:
        save_name = "_".join(args.model_name_or_path.split('/')[-2:])
        dir_name = args.model_name_or_path.split('/')[-3]
    else:
        save_name = "_".join(args.model_name_or_path.split('/')[-1:])
        dir_name = args.model_name_or_path.split('/')[-2]

    # Run every Chinese reranking task registered with MTEB.
    evaluation = MTEB(task_types=["Reranking"], task_langs=['zh'])
    evaluation.run(model, output_folder="reranker_results/{}/{}".format(dir_name, save_name))
eval_listconranker_iterative_inference.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import argparse
21
+ from modules.Reranking_loop import *
22
+ from mteb import MTEB
23
+ from modules.listconranker import ListConRanker
24
+
25
+
26
def get_args():
    """Parse command-line options for the iterative-inference evaluation script.

    Returns:
        argparse.Namespace with a single attribute, ``model_name_or_path``,
        pointing at the ListConRanker checkpoint directory.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--model_name_or_path',
        type=str,
        default="./ListConRanker_ckpt",
    )
    return arg_parser.parse_args()
30
+
31
+
32
if __name__ == '__main__':
    args = get_args()

    model = ListConRanker(args.model_name_or_path, use_fp16=True, list_transformer_layer=2)

    # Results are written to reranker_results/<dir_name>/<save_name>.
    # For a training-checkpoint path (".../run/checkpoint-XXX") the run
    # directory becomes dir_name and "run_checkpoint-XXX" the leaf; otherwise
    # the parent directory and the model directory name are used.
    # Fix: the original assigned dir_name once before this branch, but both
    # branches overwrite it, so that first assignment was dead code.
    if 'checkpoint-' in args.model_name_or_path:
        save_name = "_".join(args.model_name_or_path.split('/')[-2:])
        dir_name = args.model_name_or_path.split('/')[-3]
    else:
        save_name = "_".join(args.model_name_or_path.split('/')[-1:])
        dir_name = args.model_name_or_path.split('/')[-2]

    # Run every Chinese reranking task registered with MTEB (the import of
    # modules.Reranking_loop above patches them for iterative inference).
    evaluation = MTEB(task_types=["Reranking"], task_langs=['zh'])
    evaluation.run(model, output_folder="reranker_results/{}/{}".format(dir_name, save_name))
modules/Reranking.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import logging
21
+ import numpy as np
22
+ from mteb import RerankingEvaluator, AbsTaskReranking
23
+ from tqdm import tqdm
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class ChineseRerankingEvaluator(RerankingEvaluator):
29
+ """
30
+ This class evaluates a SentenceTransformer model for the task of re-ranking.
31
+ Given a query and a list of documents, it computes the score [query, doc_i] for all possible
32
+ documents and sorts them in decreasing order. Then, MRR@10 and MAP is compute to measure the quality of the ranking.
33
+ :param samples: Must be a list and each element is of the form:
34
+ - {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive
35
+ (relevant) documents, negative is a list of negative (irrelevant) documents.
36
+ - {'query': [], 'positive': [], 'negative': []}. Where query is a list of strings, which embeddings we average
37
+ to get the query embedding.
38
+ """
39
+
40
+ def __call__(self, model):
41
+ scores = self.compute_metrics(model)
42
+ return scores
43
+
44
+ def compute_metrics(self, model):
45
+ return (
46
+ self.compute_metrics_batched(model)
47
+ if self.use_batched_encoding
48
+ else self.compute_metrics_individual(model)
49
+ )
50
+
51
+ def compute_metrics_batched(self, model):
52
+ """
53
+ Computes the metrices in a batched way, by batching all queries and
54
+ all documents together
55
+ """
56
+
57
+ if hasattr(model, 'compute_score'):
58
+ return self.compute_metrics_batched_from_crossencoder(model)
59
+ else:
60
+ return self.compute_metrics_batched_from_biencoder(model)
61
+
62
+ def compute_metrics_batched_from_crossencoder(self, model):
63
+ batch_size = 4
64
+
65
+ all_ap_scores = []
66
+ all_mrr_1_scores = []
67
+ all_mrr_5_scores = []
68
+ all_mrr_10_scores = []
69
+
70
+ all_scores = []
71
+ tmp_pairs = []
72
+ for sample in tqdm(self.samples, desc="Evaluating"):
73
+ b_pairs = [sample['query']]
74
+ for p in sample['positive']:
75
+ b_pairs.append(p)
76
+ for n in sample['negative']:
77
+ b_pairs.append(n)
78
+ tmp_pairs.append(b_pairs)
79
+ if len(tmp_pairs) == batch_size:
80
+ sample_scores = model.compute_score(tmp_pairs)
81
+ sample_scores = sum(sample_scores, [])
82
+ all_scores += sample_scores
83
+ tmp_pairs = []
84
+ if len(tmp_pairs) > 0:
85
+ sample_scores = model.compute_score(tmp_pairs)
86
+ sample_scores = sum(sample_scores, [])
87
+ all_scores += sample_scores
88
+ all_scores = np.array(all_scores)
89
+
90
+ start_inx = 0
91
+ for sample in tqdm(self.samples, desc="Evaluating"):
92
+ is_relevant = [True] * len(sample['positive']) + [False] * len(sample['negative'])
93
+ pred_scores = all_scores[start_inx:start_inx + len(is_relevant)]
94
+ start_inx += len(is_relevant)
95
+ pred_scores_argsort = np.argsort(-pred_scores) # Sort in decreasing order
96
+
97
+ ap = self.ap_score(is_relevant, pred_scores)
98
+
99
+ mrr_1 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 1)
100
+ mrr_5 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 5)
101
+ mrr_10 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 10)
102
+
103
+ all_mrr_1_scores.append(mrr_1)
104
+ all_mrr_5_scores.append(mrr_5)
105
+ all_mrr_10_scores.append(mrr_10)
106
+ all_ap_scores.append(ap)
107
+
108
+ mean_ap = np.mean(all_ap_scores)
109
+ mean_mrr_1 = np.mean(all_mrr_1_scores)
110
+ mean_mrr_5 = np.mean(all_mrr_5_scores)
111
+ mean_mrr_10 = np.mean(all_mrr_10_scores)
112
+
113
+ return {"map": mean_ap, "mrr_1": mean_mrr_1, 'mrr_5': mean_mrr_5, 'mrr_10': mean_mrr_10}
114
+
115
+ def compute_metrics_batched_from_biencoder(self, model):
116
+ all_mrr_scores = []
117
+ all_ap_scores = []
118
+ logger.info("Encoding queries...")
119
+ if isinstance(self.samples[0]["query"], str):
120
+ if hasattr(model, 'encode_queries'):
121
+ all_query_embs = model.encode_queries(
122
+ [sample["query"] for sample in self.samples],
123
+ convert_to_tensor=True,
124
+ batch_size=self.batch_size,
125
+ )
126
+ else:
127
+ all_query_embs = model.encode(
128
+ [sample["query"] for sample in self.samples],
129
+ convert_to_tensor=True,
130
+ batch_size=self.batch_size,
131
+ )
132
+ elif isinstance(self.samples[0]["query"], list):
133
+ # In case the query is a list of strings, we get the most similar embedding to any of the queries
134
+ all_query_flattened = [q for sample in self.samples for q in sample["query"]]
135
+ if hasattr(model, 'encode_queries'):
136
+ all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
137
+ batch_size=self.batch_size)
138
+ else:
139
+ all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
140
+ else:
141
+ raise ValueError(f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}")
142
+
143
+ logger.info("Encoding candidates...")
144
+ all_docs = []
145
+ for sample in self.samples:
146
+ all_docs.extend(sample["positive"])
147
+ all_docs.extend(sample["negative"])
148
+
149
+ all_docs_embs = model.encode(all_docs, convert_to_tensor=True, batch_size=self.batch_size)
150
+
151
+ # Compute scores
152
+ logger.info("Evaluating...")
153
+ query_idx, docs_idx = 0, 0
154
+ for instance in self.samples:
155
+ num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
156
+ query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
157
+ query_idx += num_subqueries
158
+
159
+ num_pos = len(instance["positive"])
160
+ num_neg = len(instance["negative"])
161
+ docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
162
+ docs_idx += num_pos + num_neg
163
+
164
+ if num_pos == 0 or num_neg == 0:
165
+ continue
166
+
167
+ is_relevant = [True] * num_pos + [False] * num_neg
168
+
169
+ scores = self._compute_metrics_instance(query_emb, docs_emb, is_relevant)
170
+ all_mrr_scores.append(scores["mrr"])
171
+ all_ap_scores.append(scores["ap"])
172
+
173
+ mean_ap = np.mean(all_ap_scores)
174
+ mean_mrr = np.mean(all_mrr_scores)
175
+
176
+ return {"map": mean_ap, "mrr": mean_mrr}
177
+
178
+
179
+ def evaluate(self, model, split="test", **kwargs):
180
+ if not self.data_loaded:
181
+ self.load_data()
182
+
183
+ data_split = self.dataset[split]
184
+
185
+ evaluator = ChineseRerankingEvaluator(data_split, **kwargs)
186
+ scores = evaluator(model)
187
+
188
+ return dict(scores)
189
+
190
+
191
# Monkey-patch the mteb base task class so every reranking task defined below
# uses the module-level evaluate() (and thus ChineseRerankingEvaluator).
AbsTaskReranking.evaluate = evaluate
192
+
193
+
194
class T2Reranking(AbsTaskReranking):
    """MTEB task definition for the T2Reranking benchmark (Chinese, dev split)."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework: dataset location on
        # the HF Hub, evaluation split/languages and the primary metric.
        return {
            'name': 'T2Reranking',
            'hf_hub_name': "C-MTEB/T2Reranking",
            'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            "reference": "https://arxiv.org/abs/2304.03679",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['dev'],
            'eval_langs': ['zh'],
            'main_score': 'map',
        }
208
+
209
+
210
class T2RerankingZh2En(AbsTaskReranking):
    """MTEB task definition for T2Reranking with Chinese queries and English passages."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework.
        return {
            'name': 'T2RerankingZh2En',
            'hf_hub_name': "C-MTEB/T2Reranking_zh2en",
            'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            "reference": "https://arxiv.org/abs/2304.03679",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['dev'],
            'eval_langs': ['zh2en'],
            'main_score': 'map',
        }
224
+
225
+
226
class T2RerankingEn2Zh(AbsTaskReranking):
    """MTEB task definition for T2Reranking with English queries and Chinese passages."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework.
        return {
            'name': 'T2RerankingEn2Zh',
            'hf_hub_name': "C-MTEB/T2Reranking_en2zh",
            'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            "reference": "https://arxiv.org/abs/2304.03679",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['dev'],
            'eval_langs': ['en2zh'],
            'main_score': 'map',
        }
240
+
241
+
242
class MMarcoReranking(AbsTaskReranking):
    """MTEB task definition for the multilingual MS MARCO passage-reranking set."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework.
        return {
            'name': 'MMarcoReranking',
            'hf_hub_name': "C-MTEB/Mmarco-reranking",
            'description': 'mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
            "reference": "https://github.com/unicamp-dl/mMARCO",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['dev'],
            'eval_langs': ['zh'],
            'main_score': 'map',
        }
256
+
257
+
258
class CMedQAv1(AbsTaskReranking):
    """MTEB task definition for the cMedQA v1 Chinese medical QA reranking set."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework.
        return {
            'name': 'CMedQAv1',
            "hf_hub_name": "C-MTEB/CMedQAv1-reranking",
            'description': 'Chinese community medical question answering',
            "reference": "https://github.com/zhangsheng93/cMedQA",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['test'],
            'eval_langs': ['zh'],
            'main_score': 'map',
        }
272
+
273
+
274
class CMedQAv2(AbsTaskReranking):
    """MTEB task definition for the cMedQA v2 Chinese medical QA reranking set."""

    @property
    def description(self):
        # Task metadata consumed by the MTEB framework.
        return {
            'name': 'CMedQAv2',
            "hf_hub_name": "C-MTEB/CMedQAv2-reranking",
            'description': 'Chinese community medical question answering',
            "reference": "https://github.com/zhangsheng93/cMedQA2",
            'type': 'Reranking',
            'category': 's2p',
            'eval_splits': ['test'],
            'eval_langs': ['zh'],
            'main_score': 'map',
        }
modules/Reranking_loop.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import logging
21
+ import numpy as np
22
+ from mteb import RerankingEvaluator, AbsTaskReranking
23
+ from tqdm import tqdm
24
+ import math
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class ChineseRerankingEvaluator(RerankingEvaluator):
    """
    Evaluates a model for the task of re-ranking on Chinese benchmarks.

    Given a query and a list of documents, it computes the score [query, doc_i] for all
    candidate documents and sorts them in decreasing order. MRR@k and MAP are then
    computed to measure the quality of the ranking.

    :param samples: Must be a list and each element is of the form:
        - {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list
          of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
        - {'query': [], 'positive': [], 'negative': []}. Where query is a list of strings, whose
          embeddings we average to get the query embedding.
    """

    def __call__(self, model):
        scores = self.compute_metrics(model)
        return scores

    def compute_metrics(self, model):
        # compute_metrics_individual is inherited from RerankingEvaluator.
        return (
            self.compute_metrics_batched(model)
            if self.use_batched_encoding
            else self.compute_metrics_individual(model)
        )

    def compute_metrics_batched(self, model):
        """
        Computes the metrics in a batched way, by batching all queries and
        all documents together.

        Models exposing ``compute_score`` are treated as (listwise) cross-encoders;
        everything else is scored as a bi-encoder via embeddings.
        """
        if hasattr(model, 'compute_score'):
            return self.compute_metrics_batched_from_crossencoder(model)
        else:
            return self.compute_metrics_batched_from_biencoder(model)

    @staticmethod
    def _record_score(passage2score, passage_text, score):
        # Accumulate scores per passage text; passages duplicated in the input
        # keep one recorded score per occurrence in the filtering rounds.
        passage2score.setdefault(passage_text, []).append(score)

    def compute_metrics_batched_from_crossencoder(self, model):
        """Listwise cross-encoder evaluation with iterative filtering.

        While more than 20 candidates remain, the lowest-scoring ~20% (at least 10)
        are dropped each round. A passage's final score is its raw score plus the
        round index at which it was dropped, so passages surviving more rounds
        always outrank earlier drops.
        """
        all_ap_scores = []
        all_mrr_1_scores = []
        all_mrr_5_scores = []
        all_mrr_10_scores = []

        for sample in tqdm(self.samples, desc="Evaluating"):
            query = sample['query']
            pos = sample['positive']
            neg = sample['negative']
            passage = pos + neg
            passage2label = {p: True for p in pos}
            for p in neg:
                passage2label[p] = False

            filter_times = 0
            passage2score = {}
            while len(passage) > 20:
                batch = [[query] + passage]
                pred_scores = model.compute_score(batch)[0]
                # Sort in increasing order: worst candidates come first.
                pred_scores_argsort = np.argsort(pred_scores).tolist()
                passage_len = len(passage)
                # Drop 20% per round, but never fewer than 10 candidates.
                to_filter_num = max(math.ceil(passage_len * 0.2), 10)

                have_filter_num = 0
                while have_filter_num < to_filter_num:
                    idx = pred_scores_argsort[have_filter_num]
                    self._record_score(passage2score, passage[idx], pred_scores[idx] + filter_times)
                    have_filter_num += 1
                # Extend the cut over score ties so equal-scored passages are
                # filtered together. Bounds check fixes an IndexError when every
                # remaining passage carries the same score.
                while (have_filter_num < passage_len
                       and pred_scores[pred_scores_argsort[have_filter_num - 1]]
                       == pred_scores[pred_scores_argsort[have_filter_num]]):
                    idx = pred_scores_argsort[have_filter_num]
                    self._record_score(passage2score, passage[idx], pred_scores[idx] + filter_times)
                    have_filter_num += 1
                # Survivors (highest-scored remainder) move to the next round.
                passage = [passage[pred_scores_argsort[i]] for i in range(have_filter_num, passage_len)]
                filter_times += 1

            # Final round: score whatever (<= 20) passages remain. The guard
            # covers the fully-tied case where filtering removed everything.
            if passage:
                batch = [[query] + passage]
                pred_scores = model.compute_score(batch)[0]
                for cnt in range(len(passage)):
                    self._record_score(passage2score, passage[cnt], pred_scores[cnt] + filter_times)

            # Deduplicate passages, then expand each passage's recorded
            # scores/labels into flat parallel lists for the metrics.
            passage = list(set(pos + neg))
            is_relevant = []
            final_score = []
            for p in passage:
                is_relevant += [passage2label[p]] * len(passage2score[p])
                final_score += passage2score[p]

            ap = self.ap_score(is_relevant, final_score)

            # Negate so argsort yields decreasing-score order.
            pred_scores_argsort = np.argsort(-(np.array(final_score)))
            mrr_1 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 1)
            mrr_5 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 5)
            mrr_10 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 10)

            all_ap_scores.append(ap)
            all_mrr_1_scores.append(mrr_1)
            all_mrr_5_scores.append(mrr_5)
            all_mrr_10_scores.append(mrr_10)

        mean_ap = np.mean(all_ap_scores)
        mean_mrr_1 = np.mean(all_mrr_1_scores)
        mean_mrr_5 = np.mean(all_mrr_5_scores)
        mean_mrr_10 = np.mean(all_mrr_10_scores)

        return {"map": mean_ap, "mrr_1": mean_mrr_1, 'mrr_5': mean_mrr_5, 'mrr_10': mean_mrr_10}

    def compute_metrics_batched_from_biencoder(self, model):
        """Standard embedding-based evaluation: encode queries and documents
        once, then rank each sample's candidates by similarity."""
        all_mrr_scores = []
        all_ap_scores = []
        logger.info("Encoding queries...")
        if isinstance(self.samples[0]["query"], str):
            # Prefer a dedicated query encoder when the model provides one.
            encode_fn = model.encode_queries if hasattr(model, 'encode_queries') else model.encode
            all_query_embs = encode_fn(
                [sample["query"] for sample in self.samples],
                convert_to_tensor=True,
                batch_size=self.batch_size,
            )
        elif isinstance(self.samples[0]["query"], list):
            # In case the query is a list of strings, we get the most similar embedding to any of the queries.
            all_query_flattened = [q for sample in self.samples for q in sample["query"]]
            encode_fn = model.encode_queries if hasattr(model, 'encode_queries') else model.encode
            all_query_embs = encode_fn(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
        else:
            raise ValueError(f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}")

        logger.info("Encoding candidates...")
        all_docs = []
        for sample in self.samples:
            all_docs.extend(sample["positive"])
            all_docs.extend(sample["negative"])

        all_docs_embs = model.encode(all_docs, convert_to_tensor=True, batch_size=self.batch_size)

        # Compute scores; walk the flat embedding tensors sample by sample.
        logger.info("Evaluating...")
        query_idx, docs_idx = 0, 0
        for instance in self.samples:
            num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
            query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
            query_idx += num_subqueries

            num_pos = len(instance["positive"])
            num_neg = len(instance["negative"])
            docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
            docs_idx += num_pos + num_neg

            # Samples with no positives or no negatives carry no ranking signal.
            if num_pos == 0 or num_neg == 0:
                continue

            is_relevant = [True] * num_pos + [False] * num_neg

            scores = self._compute_metrics_instance(query_emb, docs_emb, is_relevant)
            all_mrr_scores.append(scores["mrr"])
            all_ap_scores.append(scores["ap"])

        mean_ap = np.mean(all_ap_scores)
        mean_mrr = np.mean(all_mrr_scores)

        return {"map": mean_ap, "mrr": mean_mrr}
215
+
216
+
217
def evaluate(self, model, split="test", **kwargs):
    """Replacement for AbsTaskReranking.evaluate that routes scoring through
    ChineseRerankingEvaluator (iterative listwise evaluation)."""
    if not self.data_loaded:
        self.load_data()

    evaluator = ChineseRerankingEvaluator(self.dataset[split], **kwargs)
    return dict(evaluator(model))


# Monkey-patch: make every MTEB reranking task use the evaluator above.
AbsTaskReranking.evaluate = evaluate
230
+
231
+
232
class T2Reranking(AbsTaskReranking):
    """T2Ranking Chinese passage re-ranking task."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='T2Reranking',
            hf_hub_name='C-MTEB/T2Reranking',
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference='https://arxiv.org/abs/2304.03679',
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
246
+
247
+
248
class T2RerankingZh2En(AbsTaskReranking):
    """T2Ranking passage re-ranking with Chinese queries against English passages."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='T2RerankingZh2En',
            hf_hub_name='C-MTEB/T2Reranking_zh2en',
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference='https://arxiv.org/abs/2304.03679',
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh2en'],
            main_score='map',
        )
262
+
263
+
264
class T2RerankingEn2Zh(AbsTaskReranking):
    """T2Ranking passage re-ranking with English queries against Chinese passages."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='T2RerankingEn2Zh',
            hf_hub_name='C-MTEB/T2Reranking_en2zh',
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference='https://arxiv.org/abs/2304.03679',
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['en2zh'],
            main_score='map',
        )
278
+
279
+
280
class MMarcoReranking(AbsTaskReranking):
    """Chinese split of mMARCO, the multilingual MS MARCO passage-ranking dataset."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='MMarcoReranking',
            hf_hub_name='C-MTEB/Mmarco-reranking',
            description='mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
            reference='https://github.com/unicamp-dl/mMARCO',
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
294
+
295
+
296
class CMedQAv1(AbsTaskReranking):
    """Chinese community medical question-answering re-ranking (cMedQA v1)."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='CMedQAv1',
            hf_hub_name='C-MTEB/CMedQAv1-reranking',
            description='Chinese community medical question answering',
            reference='https://github.com/zhangsheng93/cMedQA',
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
310
+
311
+
312
class CMedQAv2(AbsTaskReranking):
    """Chinese community medical question-answering re-ranking (cMedQA v2)."""

    @property
    def description(self):
        # Static task metadata consumed by the MTEB evaluation harness.
        return dict(
            name='CMedQAv2',
            hf_hub_name='C-MTEB/CMedQAv2-reranking',
            description='Chinese community medical question answering',
            reference='https://github.com/zhangsheng93/cMedQA2',
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
modules/listconranker.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import math
21
+ import torch
22
+ import numpy as np
23
+ from transformers import AutoTokenizer, is_torch_npu_available
24
+ from typing import Union, List
25
+ from .modeling import CrossEncoder
26
+
27
+ import os
28
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
29
+
30
+
31
def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)); works on scalars and numpy arrays."""
    return 1.0 / (np.exp(-x) + 1.0)
33
+
34
+
35
class ListConRanker:
    """Listwise cross-encoder re-ranker (eval-only wrapper).

    Wraps a ``CrossEncoder`` checkpoint plus its tokenizer and exposes listwise
    scoring (``compute_score``) and the iterative filter-and-rescore inference
    loop (``iterative_inference``).
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        use_fp16: bool = False,
        cache_dir: str = None,
        device: Union[str, int] = None,
        list_transformer_layer=None
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = CrossEncoder.from_pretrained_for_eval(model_name_or_path, list_transformer_layer)

        # Device resolution: explicit device string > CUDA index > CUDA > MPS > NPU > CPU.
        if device and isinstance(device, str):
            self.device = torch.device(device)
            if device == 'cpu':
                use_fp16 = False
        else:
            if torch.cuda.is_available():
                if device is not None:
                    self.device = torch.device(f"cuda:{device}")
                else:
                    self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False  # half precision is not supported on CPU
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()

        # DataParallel only when no explicit device was requested.
        if device is None:
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                print(f"----------using {self.num_gpus}*GPUs----------")
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1

    @torch.no_grad()
    def compute_score(self, sentence_pairs: List[List[str]], max_length: int = 512) -> List[List[float]]:
        """Score every list ``[query, p1, ..., pk]`` in one forward pass.

        Returns one list of per-passage scores per input list. A single-passage
        batch collapses to a flat one-element list (0-dim model output).
        """
        pair_nums = [len(pairs) - 1 for pairs in sentence_pairs]
        # Flatten in O(n) instead of the quadratic sum(list_of_lists, []) idiom.
        sentences_batch = [s for pairs in sentence_pairs for s in pairs]
        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(self.device)
        inputs['pair_num'] = torch.LongTensor(pair_nums)
        scores = self.model(inputs).float()
        all_scores = scores.cpu().numpy().tolist()

        # A single passage yields a 0-dim tensor -> plain float after tolist().
        if isinstance(all_scores, float):
            return [all_scores]
        result = []
        curr_idx = 0
        for n in pair_nums:
            result.append(all_scores[curr_idx: curr_idx + n])
            curr_idx += n
        return result

    @torch.no_grad()
    def iterative_inference(self, sentence_pairs: List[str], max_length: int = 512) -> List[float]:
        """Iteratively filter and rescore ``[query, p1, ..., pk]``.

        While more than 20 passages remain, the lowest-scoring ~20% (at least 10,
        extended over ties) are dropped per round; a passage's final score is its
        raw score plus the round index, so later survivors always rank higher.
        Returns scores in the caller's original passage order.
        """
        query = sentence_pairs[0]
        passage = sentence_pairs[1:]

        filter_times = 0
        passage2score = {}
        while len(passage) > 20:
            pred_scores = self.compute_score([[query] + passage], max_length)[0]
            # Sort in increasing order: worst candidates come first.
            pred_scores_argsort = np.argsort(pred_scores).tolist()
            passage_len = len(passage)
            # Drop 20% per round, but never fewer than 10 candidates.
            to_filter_num = max(math.ceil(passage_len * 0.2), 10)

            have_filter_num = 0
            while have_filter_num < to_filter_num:
                idx = pred_scores_argsort[have_filter_num]
                passage2score.setdefault(passage[idx], []).append(pred_scores[idx] + filter_times)
                have_filter_num += 1
            # Extend the cut over score ties. Bounds check fixes an IndexError
            # when every remaining passage carries the same score.
            while (have_filter_num < passage_len
                   and pred_scores[pred_scores_argsort[have_filter_num - 1]]
                   == pred_scores[pred_scores_argsort[have_filter_num]]):
                idx = pred_scores_argsort[have_filter_num]
                passage2score.setdefault(passage[idx], []).append(pred_scores[idx] + filter_times)
                have_filter_num += 1
            # Survivors (highest-scored remainder) move to the next round.
            passage = [passage[pred_scores_argsort[i]] for i in range(have_filter_num, passage_len)]
            filter_times += 1

        # Final round: score whatever (<= 20) passages remain. The guard covers
        # the fully-tied case where filtering removed everything.
        if passage:
            pred_scores = self.compute_score([[query] + passage], max_length)[0]
            for cnt, p in enumerate(passage):
                passage2score.setdefault(p, []).append(pred_scores[cnt] + filter_times)

        # Re-emit scores in the caller's original passage order.
        # NOTE(review): duplicate passage strings share one score list, so the
        # output length differs from the input when duplicates exist — confirm
        # callers never pass duplicates.
        passage = sentence_pairs[1:]
        final_score = []
        for p in passage:
            final_score += passage2score[p]
        return final_score
modules/modeling.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software
4
+ # and associated documentation files (the “Software”), to deal in the Software without
5
+ # restriction, including without limitation the rights to use, copy, modify, merge, publish,
6
+ # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
7
+ # Software is furnished to do so, subject to the following conditions:
8
+ #
9
+ # The above copyright notice and this permission notice shall be included in all copies or
10
+ # substantial portions of the Software.
11
+ #
12
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
13
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
14
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
16
+ # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
17
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
18
+ # OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import logging
21
+ import torch
22
+ from torch import nn
23
+ from transformers import AutoModel, PreTrainedModel
24
+ from torch.nn import functional as F
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class ListTransformer(nn.Module):
    """Listwise interaction module of ListConRanker.

    Takes pooled per-sequence features (1792-d after CrossEncoder's input
    projection) for a batch of [query, passages...] lists, lets each list's
    members attend to each other through a TransformerEncoder, and scores
    every passage against its query from both the pre- and post-transformer
    features.
    """

    def __init__(self, num_layer, config, device) -> None:
        super().__init__()
        self.config = config
        self.device = device
        # 1792 is the feature width produced by CrossEncoder.linear_in_embedding;
        # the head count is reused from the backbone config.
        self.list_transformer_layer = nn.TransformerEncoderLayer(1792, self.config.num_attention_heads, batch_first=True, activation=F.gelu, norm_first=False)
        self.list_transformer = nn.TransformerEncoder(self.list_transformer_layer, num_layer)
        self.relu = nn.ReLU()
        self.query_embedding = QueryEmbedding(config, device)

        # Two parallel 2*1792 -> 1792 projections (pre-/post-transformer paths)
        # whose outputs are concatenated and reduced to one logit per passage.
        self.linear_score3 = nn.Linear(1792 * 2, 1792)
        self.linear_score2 = nn.Linear(1792 * 2, 1792)
        self.linear_score1 = nn.Linear(1792 * 2, 1)

    def forward(self, pair_features, pair_nums):
        """Score all passages in a batch of lists.

        pair_features: stacked features, each list laid out as
            [query, passage_1, ..., passage_k] — assumed from the split below;
            TODO confirm against CrossEncoder.forward.
        pair_nums: number of passages per list (excluding the query).
        Returns (per-passage logits, post-transformer features concatenated
        across lists).
        """
        # +1 accounts for the query feature preceding each list's passages.
        pair_nums = [x + 1 for x in pair_nums]
        batch_pair_features = pair_features.split(pair_nums)

        # Pre-transformer path: concat(query, passage) for every passage.
        pair_feature_query_passage_concat_list = []
        for i in range(len(batch_pair_features)):
            pair_feature_query = batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
            pair_feature_passage = batch_pair_features[i][1:]
            pair_feature_query_passage_concat_list.append(torch.cat([pair_feature_query, pair_feature_passage], dim=1))
        pair_feature_query_passage_concat = torch.cat(pair_feature_query_passage_concat_list, dim=0)

        # Pad lists to a common length so they form one (batch, max_len, 1792) tensor.
        batch_pair_features = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)

        # Tag position 0 (the query) so QueryEmbedding can mark it distinctly.
        query_embedding_tags = torch.zeros(batch_pair_features.size(0), batch_pair_features.size(1), dtype=torch.long, device=self.device)
        query_embedding_tags[:, 0] = 1
        batch_pair_features = self.query_embedding(batch_pair_features, query_embedding_tags)

        # Padding mask plus a custom attention mask that blocks the query row
        # from attending to passages (see generate_attention_mask_custom).
        mask = self.generate_attention_mask(pair_nums)
        query_mask = self.generate_attention_mask_custom(pair_nums)
        pair_list_features = self.list_transformer(batch_pair_features, src_key_padding_mask=mask, mask=query_mask)

        # Split back into per-list tensors (dropping padding) and separate the
        # query row from the passage rows.
        output_pair_list_features = []
        output_query_list_features = []
        pair_features_after_transformer_list = []
        for idx, pair_num in enumerate(pair_nums):
            output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
            output_query_list_features.append(pair_list_features[idx, 0, :])
            pair_features_after_transformer_list.append(pair_list_features[idx, :pair_num, :])

        # Post-transformer path: concat(query, passage) again, now with the
        # list-aware features.
        pair_features_after_transformer_cat_query_list = []
        for idx, pair_num in enumerate(pair_nums):
            query_ft = output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
            pair_features_after_transformer_cat_query = torch.cat([query_ft, output_pair_list_features[idx]], dim=1)
            pair_features_after_transformer_cat_query_list.append(pair_features_after_transformer_cat_query)
        pair_features_after_transformer_cat_query = torch.cat(pair_features_after_transformer_cat_query_list, dim=0)

        # Fuse both paths and reduce to one logit per passage.
        pair_feature_query_passage_concat = self.relu(self.linear_score2(pair_feature_query_passage_concat))
        pair_features_after_transformer_cat_query = self.relu(self.linear_score3(pair_features_after_transformer_cat_query))
        final_ft = torch.cat([pair_feature_query_passage_concat, pair_features_after_transformer_cat_query], dim=1)
        # NOTE(review): .squeeze() drops ALL size-1 dims; with a single passage
        # this yields a 0-dim tensor — confirm downstream handles that case.
        logits = self.linear_score1(final_ft).squeeze()

        return logits, torch.cat(pair_features_after_transformer_list, dim=0)

    def generate_attention_mask(self, pair_num):
        """Key-padding mask: True marks padded positions past each list's end."""
        max_len = max(pair_num)
        batch_size = len(pair_num)
        mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
        for i, length in enumerate(pair_num):
            mask[i, length:] = True
        return mask

    def generate_attention_mask_custom(self, pair_num):
        """Attention mask shared across the batch: the query (row 0) may not
        attend to any passage; passages attend freely (True = blocked)."""
        max_len = max(pair_num)

        mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
        mask[0, 1:] = True

        return mask
102
+
103
+
104
class QueryEmbedding(nn.Module):
    """Adds a learned query/passage tag embedding to list features, then
    layer-normalizes (attribute names are fixed by the saved state_dict)."""

    def __init__(self, config, device) -> None:
        super().__init__()
        self.query_embedding = nn.Embedding(2, 1792)
        self.layerNorm = nn.LayerNorm(1792)

    def forward(self, x, tags):
        # In-place add of the tag embedding (tag 1 = query, 0 = passage),
        # followed by LayerNorm over the feature dimension.
        x += self.query_embedding(tags)
        return self.layerNorm(x)
115
+
116
+
117
class CrossEncoder(nn.Module):
    """Eval-only ListConRanker cross-encoder.

    Wraps a BERT-style backbone, average-pools its last hidden states into one
    feature per input sequence, projects 1024 -> 1792, and feeds the stacked
    features to the listwise ListTransformer for scoring.
    """

    def __init__(self, hf_model: PreTrainedModel, list_transformer_layer_4eval: int=None):
        super().__init__()
        self.hf_model = hf_model
        # NOTE(review): device is fixed here independently of where the caller
        # later moves the module (ListConRanker calls .to(device) afterwards) —
        # confirm the two never disagree.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sigmoid = nn.Sigmoid()

        self.config = self.hf_model.config
        # Required so the chunked forward below can read hidden_states[-1].
        self.config.output_hidden_states = True

        # Projects backbone pooled features (1024-d) into the 1792-d list space.
        self.linear_in_embedding = nn.Linear(1024, 1792)
        self.list_transformer_layer = list_transformer_layer_4eval
        self.list_transformer = ListTransformer(self.list_transformer_layer, self.config, self.device)

    def forward(self, batch):
        """Score a tokenized batch of [query, passages...] lists.

        batch: tokenizer output plus a 'pair_num' LongTensor holding the number
        of passages per list. Returns sigmoid scores, one per passage.
        Only the eval path is implemented; in training mode this returns None.
        """
        if 'pair_num' in batch:
            pair_nums = batch.pop('pair_num').tolist()

        if self.training:
            # Training path was stripped from this eval-only release.
            pass
        else:
            # Run the backbone in chunks of 400 sequences to bound peak memory.
            split_batch = 400
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            if sum(pair_nums) > split_batch:
                last_hidden_state_list = []
                input_ids_list = input_ids.split(split_batch)
                attention_mask_list = attention_mask.split(split_batch)
                for i in range(len(input_ids_list)):
                    last_hidden_state = self.hf_model(input_ids=input_ids_list[i], attention_mask=attention_mask_list[i], return_dict=True).hidden_states[-1]
                    last_hidden_state_list.append(last_hidden_state)
                last_hidden_state = torch.cat(last_hidden_state_list, dim=0)
            else:
                ranker_out = self.hf_model(**batch, return_dict=True)
                last_hidden_state = ranker_out.last_hidden_state

            # Mask-aware mean over tokens -> one 1024-d feature per sequence.
            pair_features = self.average_pooling(last_hidden_state, attention_mask)
            pair_features = self.linear_in_embedding(pair_features)

            # Post-transformer features are computed but unused here.
            logits, pair_features_after_list_transformer = self.list_transformer(pair_features, pair_nums)
            logits = self.sigmoid(logits)

            return logits

    @classmethod
    def from_pretrained_for_eval(cls, model_name_or_path, list_transformer_layer):
        """Build the reranker from a checkpoint directory.

        Loads the HF backbone plus the two extra weight files stored next to it.
        NOTE(review): torch.load without weights_only=True unpickles arbitrary
        objects — only load trusted checkpoints.
        """
        hf_model = AutoModel.from_pretrained(model_name_or_path)
        reranker = cls(hf_model, list_transformer_layer)
        reranker.linear_in_embedding.load_state_dict(torch.load(model_name_or_path + '/linear_in_embedding.pt'))
        reranker.list_transformer.load_state_dict(torch.load(model_name_or_path + '/list_transformer.pt'))
        return reranker

    def average_pooling(self, hidden_state, attention_mask):
        """Mean of hidden states over non-padding tokens, per sequence."""
        extended_attention_mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).to(dtype=hidden_state.dtype)
        masked_hidden_state = hidden_state * extended_attention_mask
        sum_embeddings = torch.sum(masked_hidden_state, dim=1)
        sum_mask = extended_attention_mask.sum(dim=1)
        return sum_embeddings / sum_mask
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ mteb==1.1.1
2
+ torch==2.1.2
3
+ tqdm==4.67.0
4
+ transformers==4.46.2