lvwerra (HF staff) committed
Commit f39d195
1 Parent(s): 69152e8

Update Space (evaluate main: 828c6327)

Files changed (5)
  1. README.md +118 -4
  2. app.py +6 -0
  3. record_evaluation.py +111 -0
  4. requirements.txt +4 -0
  5. super_glue.py +237 -0
README.md CHANGED
@@ -1,12 +1,126 @@
  ---
- title: Super_glue
- emoji: 👀
- colorFrom: yellow
+ title: SuperGLUE
+ emoji: 🤗
+ colorFrom: blue
  colorTo: red
  sdk: gradio
  sdk_version: 3.0.2
  app_file: app.py
  pinned: false
+ tags:
+ - evaluate
+ - metric
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ # Metric Card for SuperGLUE
+
+ ## Metric description
+ This metric computes the SuperGLUE evaluation metric associated with each of the subsets of the [SuperGLUE dataset](https://huggingface.co/datasets/super_glue).
+
+ SuperGLUE is a new benchmark styled after GLUE with a new set of more difficult language understanding tasks, improved resources, and a new public leaderboard.
+
+
+ ## How to use
+
+ There are two steps: (1) loading the SuperGLUE metric relevant to the subset of the dataset being used for evaluation, and (2) calculating the metric.
+
+ 1. **Loading the relevant SuperGLUE metric**: the subsets of SuperGLUE are the following: `boolq`, `cb`, `copa`, `multirc`, `record`, `rte`, `wic`, `wsc`, `wsc.fixed`, `axb`, `axg`.
+
+ More information about the different subsets of the SuperGLUE dataset can be found on the [SuperGLUE dataset page](https://huggingface.co/datasets/super_glue) and on the [official dataset website](https://super.gluebenchmark.com/).
+
+ 2. **Calculating the metric**: the metric takes two inputs: one list with the predictions of the model to score and one list of reference labels. The structure of both inputs depends on the SuperGLUE subset being used:
+
+ Format of `predictions`:
+ - for `record`: list of question-answer dictionaries with the following keys:
+     - `idx`: index of the question as specified by the dataset
+     - `prediction_text`: the predicted answer text
+ - for `multirc`: list of question-answer dictionaries with the following keys:
+     - `idx`: index of the question-answer pair as specified by the dataset
+     - `prediction`: the predicted answer label
+ - otherwise: list of predicted labels
+
+ Format of `references`:
+ - for `record`: list of question-answer dictionaries with the following keys:
+     - `idx`: index of the question as specified by the dataset
+     - `answers`: list of possible answers
+ - otherwise: list of reference labels
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'copa')
+ predictions = [0, 1]
+ references = [0, 1]
+ results = super_glue_metric.compute(predictions=predictions, references=references)
+ ```
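+
+ For subsets that take dictionary inputs, the loading and computing steps are the same; below is a minimal sketch for the `record` subset, mirroring the example in the module docstring (the `idx` values and answer strings are purely illustrative):
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'record')
+ predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
+ references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
+ results = super_glue_metric.compute(predictions=predictions, references=references)  # {'exact_match': 1.0, 'f1': 1.0}
+ ```
+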
+ ## Output values
+
+ The output of the metric depends on the SuperGLUE subset chosen; it is a dictionary containing one or several of the following metrics:
+
+ `exact_match`: a given predicted string's exact match score is 1 if it is exactly the same as its reference string, and 0 otherwise (see [Exact Match](https://huggingface.co/metrics/exact_match) for more information).
+
+ `f1`: the harmonic mean of the precision and recall (see [F1 score](https://huggingface.co/metrics/f1) for more information). Its range is 0 to 1: its lowest possible value is 0, if either the precision or the recall is 0, and its highest possible value is 1.0, which means perfect precision and recall.
+
+ `matthews_correlation`: a measure of the quality of binary and multiclass classifications (see [Matthews Correlation](https://huggingface.co/metrics/matthews_correlation) for more information). Its range of values is between -1 and +1, where a coefficient of +1 represents a perfect prediction, 0 an average random prediction, and -1 an inverse prediction.
+
+ ### Values from popular papers
+ The [original SuperGLUE paper](https://arxiv.org/pdf/1905.00537.pdf) reported average scores ranging from 47 to 71.5%, depending on the model used (with all evaluation values scaled by 100 to make computing the average possible).
+
+ For more recent model performance, see the [dataset leaderboard](https://super.gluebenchmark.com/leaderboard).
+
+ ## Examples
+
+ Maximal values for the COPA subset (which outputs `accuracy`):
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
+ predictions = [0, 1]
+ references = [0, 1]
+ results = super_glue_metric.compute(predictions=predictions, references=references)
+ print(results)
+ {'accuracy': 1.0}
+ ```
+
+ Minimal values for the MultiRC subset (which outputs `exact_match`, `f1_m`, and `f1_a`):
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'multirc')
+ predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
+ references = [1, 0]
+ results = super_glue_metric.compute(predictions=predictions, references=references)
+ print(results)
+ {'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}
+ ```
+
+ Partial match for the `axb` subset (which outputs `matthews_correlation`):
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'axb')
+ references = [0, 1]
+ predictions = [1, 1]
+ results = super_glue_metric.compute(predictions=predictions, references=references)
+ print(results)
+ {'matthews_correlation': 0.0}
+ ```
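+
+ Maximal values for the CB subset (which outputs both `accuracy` and `f1`), following the example given in the module docstring:
+
+ ```python
+ from evaluate import load
+ super_glue_metric = load('super_glue', 'cb')
+ predictions = [0, 1]
+ references = [0, 1]
+ results = super_glue_metric.compute(predictions=predictions, references=references)
+ print(results)
+ {'accuracy': 1.0, 'f1': 1.0}
+ ```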
+
+ ## Limitations and bias
+ This metric works only with datasets that have the same format as the [SuperGLUE dataset](https://huggingface.co/datasets/super_glue).
+
+ The dataset also includes Winogender (the `axg` subset), which is designed to measure gender bias in coreference resolution systems. However, as noted in the SuperGLUE paper, this subset has its limitations: *"It offers only positive predictive value: A poor bias score is clear evidence that a model exhibits gender bias, but a good score does not mean that the model is unbiased. [...] Also, Winogender does not cover all forms of social bias, or even all forms of gender. For instance, the version of the data used here offers no coverage of gender-neutral they or non-binary pronouns."*
+
+ ## Citation
+
+ ```bibtex
+ @article{wang2019superglue,
+   title={Super{GLUE}: A Stickier Benchmark for General-Purpose Language Understanding Systems},
+   author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
+   journal={arXiv preprint arXiv:1905.00537},
+   year={2019}
+ }
+ ```
+
+ ## Further References
+
+ - [SuperGLUE benchmark homepage](https://super.gluebenchmark.com/)
app.py ADDED
@@ -0,0 +1,6 @@
+ import evaluate
+ from evaluate.utils import launch_gradio_widget
+
+
+ module = evaluate.load("super_glue")
+ launch_gradio_widget(module)
record_evaluation.py ADDED
@@ -0,0 +1,111 @@
+ """
+ Official evaluation script for ReCoRD v1.0.
+ (Some functions are adopted from the SQuAD evaluation script.)
+ """
+
+
+ import argparse
+ import json
+ import re
+ import string
+ import sys
+ from collections import Counter
+
+
+ def normalize_answer(s):
+     """Lower text and remove punctuation, articles and extra whitespace."""
+
+     def remove_articles(text):
+         return re.sub(r"\b(a|an|the)\b", " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+ def f1_score(prediction, ground_truth):
+     prediction_tokens = normalize_answer(prediction).split()
+     ground_truth_tokens = normalize_answer(ground_truth).split()
+     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+     num_same = sum(common.values())
+     if num_same == 0:
+         return 0
+     precision = 1.0 * num_same / len(prediction_tokens)
+     recall = 1.0 * num_same / len(ground_truth_tokens)
+     f1 = (2 * precision * recall) / (precision + recall)
+     return f1
+
+
+ def exact_match_score(prediction, ground_truth):
+     return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+     scores_for_ground_truths = []
+     for ground_truth in ground_truths:
+         score = metric_fn(prediction, ground_truth)
+         scores_for_ground_truths.append(score)
+     return max(scores_for_ground_truths)
+
+
+ def evaluate(dataset, predictions):
+     f1 = exact_match = total = 0
+     correct_ids = []
+     for passage in dataset:
+         for qa in passage["qas"]:
+             total += 1
+             if qa["id"] not in predictions:
+                 message = f'Unanswered question {qa["id"]} will receive score 0.'
+                 print(message, file=sys.stderr)
+                 continue
+
+             ground_truths = list(map(lambda x: x["text"], qa["answers"]))
+             prediction = predictions[qa["id"]]
+
+             _exact_match = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
+             if int(_exact_match) == 1:
+                 correct_ids.append(qa["id"])
+             exact_match += _exact_match
+
+             f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
+
+     exact_match = exact_match / total
+     f1 = f1 / total
+
+     return {"exact_match": exact_match, "f1": f1}, correct_ids
+
+
+ if __name__ == "__main__":
+     expected_version = "1.0"
+     parser = argparse.ArgumentParser("Official evaluation script for ReCoRD v1.0.")
+     parser.add_argument("data_file", help="The dataset file in JSON format.")
+     parser.add_argument("pred_file", help="The model prediction file in JSON format.")
+     parser.add_argument("--output_correct_ids", action="store_true", help="Output the correctly answered query IDs.")
+     args = parser.parse_args()
+
+     with open(args.data_file) as data_file:
+         dataset_json = json.load(data_file)
+         if dataset_json["version"] != expected_version:
+             print(
+                 f'Evaluation expects v-{expected_version}, but got dataset with v-{dataset_json["version"]}',
+                 file=sys.stderr,
+             )
+         dataset = dataset_json["data"]
+
+     with open(args.pred_file) as pred_file:
+         predictions = json.load(pred_file)
+
+     metrics, correct_ids = evaluate(dataset, predictions)
+
+     if args.output_correct_ids:
+         print(f"Output {len(correct_ids)} correctly answered question IDs.")
+         with open("correct_ids.json", "w") as f:
+             json.dump(correct_ids, f)
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ # TODO: fix github to release
+ git+https://github.com/huggingface/evaluate.git@b6e6ed7f3e6844b297bff1b43a1b4be0709b9671
+ datasets~=2.0
+ sklearn
super_glue.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright 2020 The HuggingFace Evaluate Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """The SuperGLUE benchmark metric."""
+
+ import datasets
+ from sklearn.metrics import f1_score, matthews_corrcoef
+
+ import evaluate
+
+ from .record_evaluation import evaluate as evaluate_record
+
+
+ _CITATION = """\
+ @article{wang2019superglue,
+   title={SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
+   author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
+   journal={arXiv preprint arXiv:1905.00537},
+   year={2019}
+ }
+ """
+
+ _DESCRIPTION = """\
+ SuperGLUE (https://super.gluebenchmark.com/) is a new benchmark styled after
+ GLUE with a new set of more difficult language understanding tasks, improved
+ resources, and a new public leaderboard.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Compute the SuperGLUE evaluation metric associated with each SuperGLUE dataset.
+ Args:
+     predictions: list of predictions to score. Depending on the SuperGLUE subset:
+         - for 'record': list of question-answer dictionaries with the following keys:
+             - 'idx': index of the question as specified by the dataset
+             - 'prediction_text': the predicted answer text
+         - for 'multirc': list of question-answer dictionaries with the following keys:
+             - 'idx': index of the question-answer pair as specified by the dataset
+             - 'prediction': the predicted answer label
+         - otherwise: list of predicted labels
+     references: list of reference labels. Depending on the SuperGLUE subset:
+         - for 'record': list of question-answer dictionaries with the following keys:
+             - 'idx': index of the question as specified by the dataset
+             - 'answers': list of possible answers
+         - otherwise: list of reference labels
+ Returns: depending on the SuperGLUE subset:
+     - for 'record':
+         - 'exact_match': Exact match between answer and gold answer
+         - 'f1': F1 score
+     - for 'multirc':
+         - 'exact_match': Exact match between answer and gold answer
+         - 'f1_m': Per-question macro-F1 score
+         - 'f1_a': Average F1 score over all answers
+     - for 'axb':
+         - 'matthews_correlation': Matthews Correlation
+     - for 'cb':
+         - 'accuracy': Accuracy
+         - 'f1': F1 score
+     - for all others:
+         - 'accuracy': Accuracy
+ Examples:
+
+     >>> super_glue_metric = evaluate.load('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
+     >>> predictions = [0, 1]
+     >>> references = [0, 1]
+     >>> results = super_glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'accuracy': 1.0}
+
+     >>> super_glue_metric = evaluate.load('super_glue', 'cb')
+     >>> predictions = [0, 1]
+     >>> references = [0, 1]
+     >>> results = super_glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'accuracy': 1.0, 'f1': 1.0}
+
+     >>> super_glue_metric = evaluate.load('super_glue', 'record')
+     >>> predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
+     >>> references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
+     >>> results = super_glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'exact_match': 1.0, 'f1': 1.0}
+
+     >>> super_glue_metric = evaluate.load('super_glue', 'multirc')
+     >>> predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
+     >>> references = [0, 1]
+     >>> results = super_glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}
+
+     >>> super_glue_metric = evaluate.load('super_glue', 'axb')
+     >>> references = [0, 1]
+     >>> predictions = [0, 1]
+     >>> results = super_glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'matthews_correlation': 1.0}
+ """
+
+
+ def simple_accuracy(preds, labels):
+     return float((preds == labels).mean())
+
+
+ def acc_and_f1(preds, labels, f1_avg="binary"):
+     acc = simple_accuracy(preds, labels)
+     f1 = float(f1_score(y_true=labels, y_pred=preds, average=f1_avg))
+     return {
+         "accuracy": acc,
+         "f1": f1,
+     }
+
+
+ def evaluate_multirc(ids_preds, labels):
+     """
+     Computes F1 score and Exact Match for MultiRC predictions.
+     """
+     question_map = {}
+     for id_pred, label in zip(ids_preds, labels):
+         question_id = f'{id_pred["idx"]["paragraph"]}-{id_pred["idx"]["question"]}'
+         pred = id_pred["prediction"]
+         if question_id in question_map:
+             question_map[question_id].append((pred, label))
+         else:
+             question_map[question_id] = [(pred, label)]
+     f1s, ems = [], []
+     for question, preds_labels in question_map.items():
+         question_preds, question_labels = zip(*preds_labels)
+         f1 = f1_score(y_true=question_labels, y_pred=question_preds, average="macro")
+         f1s.append(f1)
+         em = int(sum(p == l for p, l in preds_labels) == len(preds_labels))
+         ems.append(em)
+     f1_m = float(sum(f1s) / len(f1s))
+     em = sum(ems) / len(ems)
+     f1_a = float(f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds]))
+     return {"exact_match": em, "f1_m": f1_m, "f1_a": f1_a}
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class SuperGlue(evaluate.EvaluationModule):
+     def _info(self):
+         if self.config_name not in [
+             "boolq",
+             "cb",
+             "copa",
+             "multirc",
+             "record",
+             "rte",
+             "wic",
+             "wsc",
+             "wsc.fixed",
+             "axb",
+             "axg",
+         ]:
+             raise KeyError(
+                 "You should supply a configuration name selected in "
+                 '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
+             )
+         return evaluate.EvaluationModuleInfo(
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(self._get_feature_types()),
+             codebase_urls=[],
+             reference_urls=[],
+             format="numpy" if not self.config_name == "record" and not self.config_name == "multirc" else None,
+         )
+
+     def _get_feature_types(self):
+         if self.config_name == "record":
+             return {
+                 "predictions": {
+                     "idx": {
+                         "passage": datasets.Value("int64"),
+                         "query": datasets.Value("int64"),
+                     },
+                     "prediction_text": datasets.Value("string"),
+                 },
+                 "references": {
+                     "idx": {
+                         "passage": datasets.Value("int64"),
+                         "query": datasets.Value("int64"),
+                     },
+                     "answers": datasets.Sequence(datasets.Value("string")),
+                 },
+             }
+         elif self.config_name == "multirc":
+             return {
+                 "predictions": {
+                     "idx": {
+                         "answer": datasets.Value("int64"),
+                         "paragraph": datasets.Value("int64"),
+                         "question": datasets.Value("int64"),
+                     },
+                     "prediction": datasets.Value("int64"),
+                 },
+                 "references": datasets.Value("int64"),
+             }
+         else:
+             return {
+                 "predictions": datasets.Value("int64"),
+                 "references": datasets.Value("int64"),
+             }
+
+     def _compute(self, predictions, references):
+         if self.config_name == "axb":
+             return {"matthews_correlation": matthews_corrcoef(references, predictions)}
+         elif self.config_name == "cb":
+             return acc_and_f1(predictions, references, f1_avg="macro")
+         elif self.config_name == "record":
+             dataset = [
+                 {
+                     "qas": [
+                         {"id": ref["idx"]["query"], "answers": [{"text": ans} for ans in ref["answers"]]}
+                         for ref in references
+                     ]
+                 }
+             ]
+             predictions = {pred["idx"]["query"]: pred["prediction_text"] for pred in predictions}
+             return evaluate_record(dataset, predictions)[0]
+         elif self.config_name == "multirc":
+             return evaluate_multirc(predictions, references)
+         elif self.config_name in ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]:
+             return {"accuracy": simple_accuracy(predictions, references)}
+         else:
+             raise KeyError(
+                 "You should supply a configuration name selected in "
+                 '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
+             )