nreimers
commited on
Commit
·
e6d00b4
1
Parent(s):
f7ae382
upload
Browse files- CERerankingEvaluator_results.csv +127 -0
- README.md +35 -0
- config.json +27 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer_config.json +1 -0
- train_script.py +192 -0
- vocab.txt +0 -0
CERerankingEvaluator_results.csv
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
epoch,steps,MRR@10
|
2 |
+
0,5000,0.5220761904761905
|
3 |
+
0,10000,0.5474539682539683
|
4 |
+
0,15000,0.5190095238095239
|
5 |
+
0,20000,0.5671206349206349
|
6 |
+
0,25000,0.5630952380952381
|
7 |
+
0,30000,0.565215873015873
|
8 |
+
0,35000,0.5821079365079366
|
9 |
+
0,40000,0.5496603174603175
|
10 |
+
0,45000,0.5716761904761904
|
11 |
+
0,50000,0.5707079365079364
|
12 |
+
0,55000,0.5709047619047619
|
13 |
+
0,60000,0.5643650793650793
|
14 |
+
0,65000,0.5586507936507936
|
15 |
+
0,70000,0.5792857142857142
|
16 |
+
0,75000,0.5964857142857143
|
17 |
+
0,80000,0.5759936507936507
|
18 |
+
0,85000,0.5640507936507936
|
19 |
+
0,90000,0.6016063492063491
|
20 |
+
0,95000,0.594984126984127
|
21 |
+
0,100000,0.5770507936507936
|
22 |
+
0,105000,0.605984126984127
|
23 |
+
0,110000,0.6106380952380953
|
24 |
+
0,115000,0.5763650793650794
|
25 |
+
0,120000,0.5977269841269841
|
26 |
+
0,125000,0.5764190476190475
|
27 |
+
0,130000,0.5846825396825397
|
28 |
+
0,135000,0.5810380952380951
|
29 |
+
0,140000,0.5902317460317461
|
30 |
+
0,145000,0.6034063492063492
|
31 |
+
0,150000,0.5953714285714286
|
32 |
+
0,155000,0.5992349206349207
|
33 |
+
0,160000,0.6026666666666668
|
34 |
+
0,165000,0.6046603174603176
|
35 |
+
0,170000,0.5939269841269842
|
36 |
+
0,175000,0.6007714285714286
|
37 |
+
0,180000,0.574752380952381
|
38 |
+
0,185000,0.5923619047619049
|
39 |
+
0,190000,0.600431746031746
|
40 |
+
0,195000,0.6104984126984127
|
41 |
+
0,200000,0.6154095238095239
|
42 |
+
0,205000,0.5908285714285714
|
43 |
+
0,210000,0.590936507936508
|
44 |
+
0,215000,0.6043174603174603
|
45 |
+
0,220000,0.6032825396825396
|
46 |
+
0,225000,0.6210666666666667
|
47 |
+
0,230000,0.6113396825396825
|
48 |
+
0,235000,0.6135873015873016
|
49 |
+
0,240000,0.6162285714285715
|
50 |
+
0,245000,0.6064317460317461
|
51 |
+
0,250000,0.6072285714285715
|
52 |
+
0,255000,0.6073746031746032
|
53 |
+
0,260000,0.6112857142857142
|
54 |
+
0,265000,0.6156412698412698
|
55 |
+
0,270000,0.6350095238095238
|
56 |
+
0,275000,0.6074158730158731
|
57 |
+
0,280000,0.6154761904761905
|
58 |
+
0,285000,0.6236507936507937
|
59 |
+
0,290000,0.6162412698412697
|
60 |
+
0,295000,0.616184126984127
|
61 |
+
0,300000,0.5997523809523809
|
62 |
+
0,305000,0.5937492063492064
|
63 |
+
0,310000,0.6227968253968252
|
64 |
+
0,315000,0.6274952380952382
|
65 |
+
0,320000,0.6269523809523809
|
66 |
+
0,325000,0.6306698412698413
|
67 |
+
0,330000,0.6235079365079363
|
68 |
+
0,335000,0.6206190476190477
|
69 |
+
0,340000,0.6209936507936508
|
70 |
+
0,345000,0.613095238095238
|
71 |
+
0,350000,0.6196952380952381
|
72 |
+
0,355000,0.6197301587301588
|
73 |
+
0,360000,0.6274634920634922
|
74 |
+
0,365000,0.6152730158730159
|
75 |
+
0,370000,0.6053968253968254
|
76 |
+
0,375000,0.615352380952381
|
77 |
+
0,380000,0.6110285714285715
|
78 |
+
0,385000,0.621184126984127
|
79 |
+
0,390000,0.6025619047619047
|
80 |
+
0,395000,0.6122507936507936
|
81 |
+
0,400000,0.6189079365079365
|
82 |
+
0,405000,0.6252285714285714
|
83 |
+
0,410000,0.6022634920634921
|
84 |
+
0,415000,0.6053492063492063
|
85 |
+
0,420000,0.6239619047619047
|
86 |
+
0,425000,0.6127523809523809
|
87 |
+
0,430000,0.6231873015873016
|
88 |
+
0,435000,0.6233968253968254
|
89 |
+
0,440000,0.6186825396825396
|
90 |
+
0,445000,0.6279079365079365
|
91 |
+
0,450000,0.6075079365079366
|
92 |
+
0,455000,0.603352380952381
|
93 |
+
0,460000,0.5917142857142857
|
94 |
+
0,465000,0.5998285714285714
|
95 |
+
0,470000,0.5949492063492064
|
96 |
+
0,475000,0.6139714285714285
|
97 |
+
0,480000,0.6100507936507936
|
98 |
+
0,485000,0.6057619047619048
|
99 |
+
0,490000,0.6255714285714286
|
100 |
+
0,495000,0.6058158730158729
|
101 |
+
0,500000,0.63004126984127
|
102 |
+
0,505000,0.6207269841269841
|
103 |
+
0,510000,0.6126857142857143
|
104 |
+
0,515000,0.6224825396825397
|
105 |
+
0,520000,0.6282730158730159
|
106 |
+
0,525000,0.6256634920634919
|
107 |
+
0,530000,0.6199079365079365
|
108 |
+
0,535000,0.6065555555555556
|
109 |
+
0,540000,0.6166158730158731
|
110 |
+
0,545000,0.6133936507936507
|
111 |
+
0,550000,0.6265428571428572
|
112 |
+
0,555000,0.6077619047619048
|
113 |
+
0,560000,0.6010984126984126
|
114 |
+
0,565000,0.6134158730158731
|
115 |
+
0,570000,0.6211714285714286
|
116 |
+
0,575000,0.6167301587301588
|
117 |
+
0,580000,0.6193968253968253
|
118 |
+
0,585000,0.605352380952381
|
119 |
+
0,590000,0.6013523809523811
|
120 |
+
0,595000,0.6070285714285715
|
121 |
+
0,600000,0.6075492063492064
|
122 |
+
0,605000,0.6051396825396825
|
123 |
+
0,610000,0.609984126984127
|
124 |
+
0,615000,0.6076412698412699
|
125 |
+
0,620000,0.604384126984127
|
126 |
+
0,625000,0.6051396825396825
|
127 |
+
0,-1,0.6051396825396825
|
README.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cross-Encoder for MS Marco
|
2 |
+
|
3 |
+
This model uses [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT), a tiny BERT model with only 4 layers. The base model is: General_TinyBERT_v2(4layer-312dim)
|
4 |
+
|
5 |
+
It was trained on [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
6 |
+
|
7 |
+
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Information Retrieval](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/information-retrieval) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
|
8 |
+
|
9 |
+
## Usage and Performance
|
10 |
+
|
11 |
+
Pre-trained models can be used like this:
|
12 |
+
```
|
13 |
+
from sentence_transformers import CrossEncoder
|
14 |
+
model = CrossEncoder('model_name', max_length=512)
|
15 |
+
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
|
16 |
+
```
|
17 |
+
|
18 |
+
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
|
19 |
+
|
20 |
+
|
21 |
+
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec (BertTokenizerFast) | Docs / Sec |
|
22 |
+
| ------------- |:-------------| -----| --- | --- |
|
23 |
+
| cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000 | 780
|
24 |
+
| cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900 | 760
|
25 |
+
| cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680 | 660
|
26 |
+
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 | 340
|
27 |
+
| *Other models* | | | |
|
28 |
+
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900 | 760
|
29 |
+
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340 | 340|
|
30 |
+
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100 | 100 |
|
31 |
+
| Capreolus/electra-base-msmarco | 71.23 | | 340 | 340 |
|
32 |
+
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | | 330 | 330
|
33 |
+
|
34 |
+
Note: Runtime was computed on a V100 GPU. A bottleneck for smaller models is the standard Python tokenizer from Huggingface in version 3. Replacing it with the fast tokenizer based on Rust, the throughput is significantly improved:
|
35 |
+
|
config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "nreimers/TinyBERT_L-4_H-312_v2",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"gradient_checkpointing": false,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 312,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0"
|
13 |
+
},
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 1200,
|
16 |
+
"label2id": {
|
17 |
+
"LABEL_0": 0
|
18 |
+
},
|
19 |
+
"layer_norm_eps": 1e-12,
|
20 |
+
"max_position_embeddings": 512,
|
21 |
+
"model_type": "bert",
|
22 |
+
"num_attention_heads": 12,
|
23 |
+
"num_hidden_layers": 4,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"type_vocab_size": 2,
|
26 |
+
"vocab_size": 30522
|
27 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b06785cbb737ac18adf49a4ac3ce4d724479afc9c133d5e83172b12db996dc91
|
3 |
+
size 57436041
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": "/home/ukp-reimers/.cache/torch/transformers/f96b11e14fec8f4be06121e7f6bbe07f82216bf7d75ad76fe3a81251e8895d69.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "tokenizer_file": null, "name_or_path": "nreimers/TinyBERT_L-4_H-312_v2"}
|
train_script.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
from sentence_transformers import LoggingHandler
|
3 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
4 |
+
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
|
5 |
+
from sentence_transformers import InputExample
|
6 |
+
import logging
|
7 |
+
from datetime import datetime
|
8 |
+
import gzip
|
9 |
+
import sys
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from shutil import copyfile
|
13 |
+
import csv
|
14 |
+
import tqdm
|
15 |
+
|
16 |
+
#### Just some code to print debug information to stdout
|
17 |
+
logging.basicConfig(format='%(asctime)s - %(message)s',
|
18 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
19 |
+
level=logging.INFO,
|
20 |
+
handlers=[LoggingHandler()])
|
21 |
+
#### /print debug information to stdout
|
22 |
+
|
23 |
+
|
24 |
+
#Define our Cross-Encoder
|
25 |
+
model_name = sys.argv[1] #'google/electra-small-discriminator'
|
26 |
+
train_batch_size = 32
|
27 |
+
num_epochs = 1
|
28 |
+
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
29 |
+
|
30 |
+
#We set num_labels=1, which predicts a continous score between 0 and 1
|
31 |
+
model = CrossEncoder(model_name, num_labels=1, max_length=512)
|
32 |
+
|
33 |
+
|
34 |
+
# Write self to path
|
35 |
+
os.makedirs(model_save_path, exist_ok=True)
|
36 |
+
|
37 |
+
train_script_path = os.path.join(model_save_path, 'train_script.py')
|
38 |
+
copyfile(__file__, train_script_path)
|
39 |
+
with open(train_script_path, 'a') as fOut:
|
40 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
41 |
+
|
42 |
+
|
43 |
+
corpus = {}
|
44 |
+
queries = {}
|
45 |
+
|
46 |
+
#### Read train file
|
47 |
+
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
|
48 |
+
for line in fIn:
|
49 |
+
pid, passage = line.strip().split("\t")
|
50 |
+
corpus[pid] = passage
|
51 |
+
|
52 |
+
with open('../data/queries.train.tsv', 'r') as fIn:
|
53 |
+
for line in fIn:
|
54 |
+
qid, query = line.strip().split("\t")
|
55 |
+
queries[qid] = query
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
pos_neg_ration = (4+1)
|
60 |
+
cnt = 0
|
61 |
+
train_samples = []
|
62 |
+
dev_samples = {}
|
63 |
+
|
64 |
+
num_dev_queries = 125
|
65 |
+
num_max_dev_negatives = 200
|
66 |
+
|
67 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
|
68 |
+
for line in fIn:
|
69 |
+
qid, pos_id, neg_id = line.strip().split()
|
70 |
+
|
71 |
+
if qid not in dev_samples and len(dev_samples) < num_dev_queries:
|
72 |
+
dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
|
73 |
+
|
74 |
+
if qid in dev_samples:
|
75 |
+
dev_samples[qid]['positive'].add(corpus[pos_id])
|
76 |
+
|
77 |
+
if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
|
78 |
+
dev_samples[qid]['negative'].add(corpus[neg_id])
|
79 |
+
|
80 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
|
81 |
+
for line in tqdm.tqdm(fIn, unit_scale=True):
|
82 |
+
cnt += 1
|
83 |
+
qid, pos_id, neg_id = line.strip().split()
|
84 |
+
query = queries[qid]
|
85 |
+
if (cnt % pos_neg_ration) == 0:
|
86 |
+
passage = corpus[pos_id]
|
87 |
+
label = 1
|
88 |
+
else:
|
89 |
+
passage = corpus[neg_id]
|
90 |
+
label = 0
|
91 |
+
|
92 |
+
train_samples.append(InputExample(texts=[query, passage], label=label))
|
93 |
+
|
94 |
+
if len(train_samples) >= 2e7:
|
95 |
+
break
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
|
100 |
+
|
101 |
+
# We add an evaluator, which evaluates the performance during training
|
102 |
+
|
103 |
+
class CERerankingEvaluator:
|
104 |
+
def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
|
105 |
+
self.samples = samples
|
106 |
+
self.name = name
|
107 |
+
self.mrr_at_k = mrr_at_k
|
108 |
+
|
109 |
+
if isinstance(self.samples, dict):
|
110 |
+
self.samples = list(self.samples.values())
|
111 |
+
|
112 |
+
self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
|
113 |
+
self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
|
114 |
+
|
115 |
+
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
|
116 |
+
if epoch != -1:
|
117 |
+
if steps == -1:
|
118 |
+
out_txt = " after epoch {}:".format(epoch)
|
119 |
+
else:
|
120 |
+
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
|
121 |
+
else:
|
122 |
+
out_txt = ":"
|
123 |
+
|
124 |
+
logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
|
125 |
+
|
126 |
+
all_mrr_scores = []
|
127 |
+
num_queries = 0
|
128 |
+
num_positives = []
|
129 |
+
num_negatives = []
|
130 |
+
for instance in self.samples:
|
131 |
+
query = instance['query']
|
132 |
+
positive = list(instance['positive'])
|
133 |
+
negative = list(instance['negative'])
|
134 |
+
docs = positive + negative
|
135 |
+
is_relevant = [True]*len(positive) + [False]*len(negative)
|
136 |
+
|
137 |
+
if len(positive) == 0 or len(negative) == 0:
|
138 |
+
continue
|
139 |
+
|
140 |
+
num_queries += 1
|
141 |
+
num_positives.append(len(positive))
|
142 |
+
num_negatives.append(len(negative))
|
143 |
+
|
144 |
+
model_input = [[query, doc] for doc in docs]
|
145 |
+
pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
|
146 |
+
pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
|
147 |
+
|
148 |
+
mrr_score = 0
|
149 |
+
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
|
150 |
+
if is_relevant[index]:
|
151 |
+
mrr_score = 1 / (rank+1)
|
152 |
+
|
153 |
+
all_mrr_scores.append(mrr_score)
|
154 |
+
|
155 |
+
mean_mrr = np.mean(all_mrr_scores)
|
156 |
+
logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
|
157 |
+
logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
|
158 |
+
|
159 |
+
if output_path is not None:
|
160 |
+
csv_path = os.path.join(output_path, self.csv_file)
|
161 |
+
output_file_exists = os.path.isfile(csv_path)
|
162 |
+
with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
|
163 |
+
writer = csv.writer(f)
|
164 |
+
if not output_file_exists:
|
165 |
+
writer.writerow(self.csv_headers)
|
166 |
+
|
167 |
+
writer.writerow([epoch, steps, mean_mrr])
|
168 |
+
|
169 |
+
return mean_mrr
|
170 |
+
|
171 |
+
|
172 |
+
evaluator = CERerankingEvaluator(dev_samples)
|
173 |
+
|
174 |
+
# Configure the training
|
175 |
+
warmup_steps = 5000
|
176 |
+
logging.info("Warmup-steps: {}".format(warmup_steps))
|
177 |
+
|
178 |
+
|
179 |
+
# Train the model
|
180 |
+
model.fit(train_dataloader=train_dataloader,
|
181 |
+
evaluator=evaluator,
|
182 |
+
epochs=num_epochs,
|
183 |
+
evaluation_steps=5000,
|
184 |
+
warmup_steps=warmup_steps,
|
185 |
+
output_path=model_save_path,
|
186 |
+
use_amp=True)
|
187 |
+
|
188 |
+
#Save latest model
|
189 |
+
model.save(model_save_path+'-latest')
|
190 |
+
|
191 |
+
# Script was called via:
|
192 |
+
#python train_cross-encoder.py nreimers/TinyBERT_L-4_H-312_v2
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|