nreimers commited on
Commit
e6d00b4
·
1 Parent(s): f7ae382
CERerankingEvaluator_results.csv ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,steps,MRR@10
2
+ 0,5000,0.5220761904761905
3
+ 0,10000,0.5474539682539683
4
+ 0,15000,0.5190095238095239
5
+ 0,20000,0.5671206349206349
6
+ 0,25000,0.5630952380952381
7
+ 0,30000,0.565215873015873
8
+ 0,35000,0.5821079365079366
9
+ 0,40000,0.5496603174603175
10
+ 0,45000,0.5716761904761904
11
+ 0,50000,0.5707079365079364
12
+ 0,55000,0.5709047619047619
13
+ 0,60000,0.5643650793650793
14
+ 0,65000,0.5586507936507936
15
+ 0,70000,0.5792857142857142
16
+ 0,75000,0.5964857142857143
17
+ 0,80000,0.5759936507936507
18
+ 0,85000,0.5640507936507936
19
+ 0,90000,0.6016063492063491
20
+ 0,95000,0.594984126984127
21
+ 0,100000,0.5770507936507936
22
+ 0,105000,0.605984126984127
23
+ 0,110000,0.6106380952380953
24
+ 0,115000,0.5763650793650794
25
+ 0,120000,0.5977269841269841
26
+ 0,125000,0.5764190476190475
27
+ 0,130000,0.5846825396825397
28
+ 0,135000,0.5810380952380951
29
+ 0,140000,0.5902317460317461
30
+ 0,145000,0.6034063492063492
31
+ 0,150000,0.5953714285714286
32
+ 0,155000,0.5992349206349207
33
+ 0,160000,0.6026666666666668
34
+ 0,165000,0.6046603174603176
35
+ 0,170000,0.5939269841269842
36
+ 0,175000,0.6007714285714286
37
+ 0,180000,0.574752380952381
38
+ 0,185000,0.5923619047619049
39
+ 0,190000,0.600431746031746
40
+ 0,195000,0.6104984126984127
41
+ 0,200000,0.6154095238095239
42
+ 0,205000,0.5908285714285714
43
+ 0,210000,0.590936507936508
44
+ 0,215000,0.6043174603174603
45
+ 0,220000,0.6032825396825396
46
+ 0,225000,0.6210666666666667
47
+ 0,230000,0.6113396825396825
48
+ 0,235000,0.6135873015873016
49
+ 0,240000,0.6162285714285715
50
+ 0,245000,0.6064317460317461
51
+ 0,250000,0.6072285714285715
52
+ 0,255000,0.6073746031746032
53
+ 0,260000,0.6112857142857142
54
+ 0,265000,0.6156412698412698
55
+ 0,270000,0.6350095238095238
56
+ 0,275000,0.6074158730158731
57
+ 0,280000,0.6154761904761905
58
+ 0,285000,0.6236507936507937
59
+ 0,290000,0.6162412698412697
60
+ 0,295000,0.616184126984127
61
+ 0,300000,0.5997523809523809
62
+ 0,305000,0.5937492063492064
63
+ 0,310000,0.6227968253968252
64
+ 0,315000,0.6274952380952382
65
+ 0,320000,0.6269523809523809
66
+ 0,325000,0.6306698412698413
67
+ 0,330000,0.6235079365079363
68
+ 0,335000,0.6206190476190477
69
+ 0,340000,0.6209936507936508
70
+ 0,345000,0.613095238095238
71
+ 0,350000,0.6196952380952381
72
+ 0,355000,0.6197301587301588
73
+ 0,360000,0.6274634920634922
74
+ 0,365000,0.6152730158730159
75
+ 0,370000,0.6053968253968254
76
+ 0,375000,0.615352380952381
77
+ 0,380000,0.6110285714285715
78
+ 0,385000,0.621184126984127
79
+ 0,390000,0.6025619047619047
80
+ 0,395000,0.6122507936507936
81
+ 0,400000,0.6189079365079365
82
+ 0,405000,0.6252285714285714
83
+ 0,410000,0.6022634920634921
84
+ 0,415000,0.6053492063492063
85
+ 0,420000,0.6239619047619047
86
+ 0,425000,0.6127523809523809
87
+ 0,430000,0.6231873015873016
88
+ 0,435000,0.6233968253968254
89
+ 0,440000,0.6186825396825396
90
+ 0,445000,0.6279079365079365
91
+ 0,450000,0.6075079365079366
92
+ 0,455000,0.603352380952381
93
+ 0,460000,0.5917142857142857
94
+ 0,465000,0.5998285714285714
95
+ 0,470000,0.5949492063492064
96
+ 0,475000,0.6139714285714285
97
+ 0,480000,0.6100507936507936
98
+ 0,485000,0.6057619047619048
99
+ 0,490000,0.6255714285714286
100
+ 0,495000,0.6058158730158729
101
+ 0,500000,0.63004126984127
102
+ 0,505000,0.6207269841269841
103
+ 0,510000,0.6126857142857143
104
+ 0,515000,0.6224825396825397
105
+ 0,520000,0.6282730158730159
106
+ 0,525000,0.6256634920634919
107
+ 0,530000,0.6199079365079365
108
+ 0,535000,0.6065555555555556
109
+ 0,540000,0.6166158730158731
110
+ 0,545000,0.6133936507936507
111
+ 0,550000,0.6265428571428572
112
+ 0,555000,0.6077619047619048
113
+ 0,560000,0.6010984126984126
114
+ 0,565000,0.6134158730158731
115
+ 0,570000,0.6211714285714286
116
+ 0,575000,0.6167301587301588
117
+ 0,580000,0.6193968253968253
118
+ 0,585000,0.605352380952381
119
+ 0,590000,0.6013523809523811
120
+ 0,595000,0.6070285714285715
121
+ 0,600000,0.6075492063492064
122
+ 0,605000,0.6051396825396825
123
+ 0,610000,0.609984126984127
124
+ 0,615000,0.6076412698412699
125
+ 0,620000,0.604384126984127
126
+ 0,625000,0.6051396825396825
127
+ 0,-1,0.6051396825396825
README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cross-Encoder for MS Marco
2
+
3
+ This model uses [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT), a tiny BERT model with only 4 layers. The base model is: General_TinyBERT_v2(4layer-312dim)
4
+
5
+ It was trained on [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
6
+
7
+ The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Information Retrieval](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/information-retrieval) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
8
+
9
+ ## Usage and Performance
10
+
11
+ Pre-trained models can be used like this:
12
+ ```
13
+ from sentence_transformers import CrossEncoder
14
+ model = CrossEncoder('model_name', max_length=512)
15
+ scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
16
+ ```
17
+
18
+ In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
19
+
20
+
21
+ | Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec (BertTokenizerFast) | Docs / Sec |
22
+ | ------------- |:-------------| -----| --- | --- |
23
+ | cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000 | 780
24
+ | cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900 | 760
25
+ | cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680 | 660
26
+ | cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 | 340
27
+ | *Other models* | | | |
28
+ | nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900 | 760
29
+ | nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340 | 340|
30
+ | nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100 | 100 |
31
+ | Capreolus/electra-base-msmarco | 71.23 | | 340 | 340 |
32
+ | amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | | 330 | 330
33
+
34
+ Note: Runtime was computed on a V100 GPU. A bottleneck for smaller models is the standard Python tokenizer from Huggingface in version 3. Replacing it with the fast tokenizer based on Rust, the throughput is significantly improved:
35
+
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nreimers/TinyBERT_L-4_H-312_v2",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 312,
11
+ "id2label": {
12
+ "0": "LABEL_0"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 1200,
16
+ "label2id": {
17
+ "LABEL_0": 0
18
+ },
19
+ "layer_norm_eps": 1e-12,
20
+ "max_position_embeddings": 512,
21
+ "model_type": "bert",
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 4,
24
+ "pad_token_id": 0,
25
+ "type_vocab_size": 2,
26
+ "vocab_size": 30522
27
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b06785cbb737ac18adf49a4ac3ce4d724479afc9c133d5e83172b12db996dc91
3
+ size 57436041
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": "/home/ukp-reimers/.cache/torch/transformers/f96b11e14fec8f4be06121e7f6bbe07f82216bf7d75ad76fe3a81251e8895d69.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "tokenizer_file": null, "name_or_path": "nreimers/TinyBERT_L-4_H-312_v2"}
train_script.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from sentence_transformers import LoggingHandler
3
+ from sentence_transformers.cross_encoder import CrossEncoder
4
+ from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
5
+ from sentence_transformers import InputExample
6
+ import logging
7
+ from datetime import datetime
8
+ import gzip
9
+ import sys
10
+ import numpy as np
11
+ import os
12
+ from shutil import copyfile
13
+ import csv
14
+ import tqdm
15
+
16
+ #### Just some code to print debug information to stdout
17
+ logging.basicConfig(format='%(asctime)s - %(message)s',
18
+ datefmt='%Y-%m-%d %H:%M:%S',
19
+ level=logging.INFO,
20
+ handlers=[LoggingHandler()])
21
+ #### /print debug information to stdout
22
+
23
+
24
+ #Define our Cross-Encoder
25
+ model_name = sys.argv[1] #'google/electra-small-discriminator'
26
+ train_batch_size = 32
27
+ num_epochs = 1
28
+ model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
29
+
30
+ #We set num_labels=1, which predicts a continous score between 0 and 1
31
+ model = CrossEncoder(model_name, num_labels=1, max_length=512)
32
+
33
+
34
+ # Write self to path
35
+ os.makedirs(model_save_path, exist_ok=True)
36
+
37
+ train_script_path = os.path.join(model_save_path, 'train_script.py')
38
+ copyfile(__file__, train_script_path)
39
+ with open(train_script_path, 'a') as fOut:
40
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
41
+
42
+
43
+ corpus = {}
44
+ queries = {}
45
+
46
+ #### Read train file
47
+ with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
48
+ for line in fIn:
49
+ pid, passage = line.strip().split("\t")
50
+ corpus[pid] = passage
51
+
52
+ with open('../data/queries.train.tsv', 'r') as fIn:
53
+ for line in fIn:
54
+ qid, query = line.strip().split("\t")
55
+ queries[qid] = query
56
+
57
+
58
+
59
+ pos_neg_ration = (4+1)
60
+ cnt = 0
61
+ train_samples = []
62
+ dev_samples = {}
63
+
64
+ num_dev_queries = 125
65
+ num_max_dev_negatives = 200
66
+
67
+ with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
68
+ for line in fIn:
69
+ qid, pos_id, neg_id = line.strip().split()
70
+
71
+ if qid not in dev_samples and len(dev_samples) < num_dev_queries:
72
+ dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
73
+
74
+ if qid in dev_samples:
75
+ dev_samples[qid]['positive'].add(corpus[pos_id])
76
+
77
+ if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
78
+ dev_samples[qid]['negative'].add(corpus[neg_id])
79
+
80
+ with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
81
+ for line in tqdm.tqdm(fIn, unit_scale=True):
82
+ cnt += 1
83
+ qid, pos_id, neg_id = line.strip().split()
84
+ query = queries[qid]
85
+ if (cnt % pos_neg_ration) == 0:
86
+ passage = corpus[pos_id]
87
+ label = 1
88
+ else:
89
+ passage = corpus[neg_id]
90
+ label = 0
91
+
92
+ train_samples.append(InputExample(texts=[query, passage], label=label))
93
+
94
+ if len(train_samples) >= 2e7:
95
+ break
96
+
97
+
98
+
99
+ train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
100
+
101
+ # We add an evaluator, which evaluates the performance during training
102
+
103
+ class CERerankingEvaluator:
104
+ def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
105
+ self.samples = samples
106
+ self.name = name
107
+ self.mrr_at_k = mrr_at_k
108
+
109
+ if isinstance(self.samples, dict):
110
+ self.samples = list(self.samples.values())
111
+
112
+ self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
113
+ self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
114
+
115
+ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
116
+ if epoch != -1:
117
+ if steps == -1:
118
+ out_txt = " after epoch {}:".format(epoch)
119
+ else:
120
+ out_txt = " in epoch {} after {} steps:".format(epoch, steps)
121
+ else:
122
+ out_txt = ":"
123
+
124
+ logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
125
+
126
+ all_mrr_scores = []
127
+ num_queries = 0
128
+ num_positives = []
129
+ num_negatives = []
130
+ for instance in self.samples:
131
+ query = instance['query']
132
+ positive = list(instance['positive'])
133
+ negative = list(instance['negative'])
134
+ docs = positive + negative
135
+ is_relevant = [True]*len(positive) + [False]*len(negative)
136
+
137
+ if len(positive) == 0 or len(negative) == 0:
138
+ continue
139
+
140
+ num_queries += 1
141
+ num_positives.append(len(positive))
142
+ num_negatives.append(len(negative))
143
+
144
+ model_input = [[query, doc] for doc in docs]
145
+ pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
146
+ pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
147
+
148
+ mrr_score = 0
149
+ for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
150
+ if is_relevant[index]:
151
+ mrr_score = 1 / (rank+1)
152
+
153
+ all_mrr_scores.append(mrr_score)
154
+
155
+ mean_mrr = np.mean(all_mrr_scores)
156
+ logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
157
+ logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
158
+
159
+ if output_path is not None:
160
+ csv_path = os.path.join(output_path, self.csv_file)
161
+ output_file_exists = os.path.isfile(csv_path)
162
+ with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
163
+ writer = csv.writer(f)
164
+ if not output_file_exists:
165
+ writer.writerow(self.csv_headers)
166
+
167
+ writer.writerow([epoch, steps, mean_mrr])
168
+
169
+ return mean_mrr
170
+
171
+
172
+ evaluator = CERerankingEvaluator(dev_samples)
173
+
174
+ # Configure the training
175
+ warmup_steps = 5000
176
+ logging.info("Warmup-steps: {}".format(warmup_steps))
177
+
178
+
179
+ # Train the model
180
+ model.fit(train_dataloader=train_dataloader,
181
+ evaluator=evaluator,
182
+ epochs=num_epochs,
183
+ evaluation_steps=5000,
184
+ warmup_steps=warmup_steps,
185
+ output_path=model_save_path,
186
+ use_amp=True)
187
+
188
+ #Save latest model
189
+ model.save(model_save_path+'-latest')
190
+
191
+ # Script was called via:
192
+ #python train_cross-encoder.py nreimers/TinyBERT_L-4_H-312_v2
vocab.txt ADDED
The diff for this file is too large to render. See raw diff