Update Space (evaluate main: 828c6327)
Files changed:
- README.md +118 -4
- app.py +6 -0
- record_evaluation.py +111 -0
- requirements.txt +4 -0
- super_glue.py +237 -0
README.md
CHANGED
```diff
@@ -1,12 +1,126 @@
 ---
-title:
-emoji:
-colorFrom:
+title: SuperGLUE
+emoji: 🤗
+colorFrom: blue
 colorTo: red
 sdk: gradio
 sdk_version: 3.0.2
 app_file: app.py
 pinned: false
+tags:
+- evaluate
+- metric
 ---
```

The rest of the hunk adds the metric card content:

# Metric Card for SuperGLUE

## Metric description
This metric is used to compute the SuperGLUE evaluation metric associated with each of the subsets of the [SuperGLUE dataset](https://huggingface.co/datasets/super_glue).

SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language understanding tasks, improved resources, and a new public leaderboard.

## How to use

There are two steps: (1) loading the SuperGLUE metric relevant to the subset of the dataset being used for evaluation; and (2) calculating the metric.

1. **Loading the relevant SuperGLUE metric**: the subsets of SuperGLUE are the following: `boolq`, `cb`, `copa`, `multirc`, `record`, `rte`, `wic`, `wsc`, `wsc.fixed`, `axb`, `axg`.

More information about the different subsets of the SuperGLUE dataset can be found on the [SuperGLUE dataset page](https://huggingface.co/datasets/super_glue) and on the [official dataset website](https://super.gluebenchmark.com/).

2. **Calculating the metric**: the metric takes two inputs: one list with the predictions of the model to score and one list of reference labels. The structure of both inputs depends on the SuperGLUE subset being used:

Format of `predictions`:
- for `record`: list of question-answer dictionaries with the following keys:
    - `idx`: index of the question as specified by the dataset
    - `prediction_text`: the predicted answer text
- for `multirc`: list of question-answer dictionaries with the following keys:
    - `idx`: index of the question-answer pair as specified by the dataset
    - `prediction`: the predicted answer label
- otherwise: list of predicted labels

Format of `references`:
- for `record`: list of question-answer dictionaries with the following keys:
    - `idx`: index of the question as specified by the dataset
    - `answers`: list of possible answers
- otherwise: list of reference labels

```python
from evaluate import load
super_glue_metric = load('super_glue', 'copa')
predictions = [0, 1]
references = [0, 1]
results = super_glue_metric.compute(predictions=predictions, references=references)
```
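For subsets that take structured inputs, predictions and references are lists of dictionaries rather than plain labels. The sketch below mirrors the `record` example from the module docstring; the `idx` and answer values are purely illustrative:

```python
from evaluate import load
super_glue_metric = load('super_glue', 'record')
predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
results = super_glue_metric.compute(predictions=predictions, references=references)
print(results)
{'exact_match': 1.0, 'f1': 1.0}
```

For `multirc`, the same pattern applies, with a `prediction` label in place of `prediction_text` (see the MultiRC example under Examples below).
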
## Output values

The output of the metric depends on the SuperGLUE subset chosen: it is a dictionary containing one or several of the following metrics:

`accuracy`: the proportion of predictions that exactly match the reference labels. Its range is 0-1, where 1.0 means all predictions are correct. (See [Accuracy](https://huggingface.co/metrics/accuracy) for more information.)

`exact_match`: a given predicted string's exact match score is 1 if it is the exact same as its reference string, and is 0 otherwise. (See [Exact Match](https://huggingface.co/metrics/exact_match) for more information.)

`f1`: the harmonic mean of the precision and recall (see [F1 score](https://huggingface.co/metrics/f1) for more information). Its range is 0-1: its lowest possible value is 0, if either the precision or the recall is 0, and its highest possible value is 1.0, which means perfect precision and recall.

`matthews_correlation`: a measure of the quality of binary and multiclass classifications (see [Matthews Correlation](https://huggingface.co/metrics/matthews_correlation) for more information). Its range of values is between -1 and +1, where a coefficient of +1 represents a perfect prediction, 0 an average random prediction and -1 an inverse prediction.

### Values from popular papers
The [original SuperGLUE paper](https://arxiv.org/pdf/1905.00537.pdf) reported average scores ranging from 47 to 71.5%, depending on the model used (with all evaluation values scaled by 100 to make computing the average possible).

For more recent model performance, see the [dataset leaderboard](https://super.gluebenchmark.com/leaderboard).

## Examples

Maximal values for the COPA subset (which outputs `accuracy`):

```python
from evaluate import load
super_glue_metric = load('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
predictions = [0, 1]
references = [0, 1]
results = super_glue_metric.compute(predictions=predictions, references=references)
print(results)
{'accuracy': 1.0}
```

Minimal values for the MultiRC subset (which outputs `exact_match`, `f1_m` and `f1_a`):

```python
from evaluate import load
super_glue_metric = load('super_glue', 'multirc')
predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
references = [1, 0]
results = super_glue_metric.compute(predictions=predictions, references=references)
print(results)
{'exact_match': 0.0, 'f1_m': 0.0, 'f1_a': 0.0}
```

Partial match for the AX-b subset (which outputs `matthews_correlation`):

```python
from evaluate import load
super_glue_metric = load('super_glue', 'axb')
references = [0, 1]
predictions = [1, 1]
results = super_glue_metric.compute(predictions=predictions, references=references)
print(results)
{'matthews_correlation': 0.0}
```

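Maximal values for the CB subset (which outputs both `accuracy` and `f1`), mirroring the `cb` example in the module docstring:

```python
from evaluate import load
super_glue_metric = load('super_glue', 'cb')
predictions = [0, 1]
references = [0, 1]
results = super_glue_metric.compute(predictions=predictions, references=references)
print(results)
{'accuracy': 1.0, 'f1': 1.0}
```
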
## Limitations and bias
This metric works only with datasets that have the same format as the [SuperGLUE dataset](https://huggingface.co/datasets/super_glue).

The dataset also includes Winogender, a subset of the dataset that is designed to measure gender bias in coreference resolution systems. However, as noted in the SuperGLUE paper, this subset has its limitations: *"It offers only positive predictive value: A poor bias score is clear evidence that a model exhibits gender bias, but a good score does not mean that the model is unbiased. [...] Also, Winogender does not cover all forms of social bias, or even all forms of gender. For instance, the version of the data used here offers no coverage of gender-neutral they or non-binary pronouns."*

## Citation

```bibtex
@article{wang2019superglue,
  title={Super{GLUE}: A Stickier Benchmark for General-Purpose Language Understanding Systems},
  author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
  journal={arXiv preprint arXiv:1905.00537},
  year={2019}
}
```

## Further References

- [SuperGLUE benchmark homepage](https://super.gluebenchmark.com/)
app.py
ADDED
@@ -0,0 +1,6 @@
```python
import evaluate
from evaluate.utils import launch_gradio_widget


# Load the SuperGLUE metric module and expose it as an interactive Gradio widget.
module = evaluate.load("super_glue")
launch_gradio_widget(module)
```
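The widget wraps the same module that can be loaded programmatically; a minimal sketch, with `copa` chosen only as an illustrative config:

```python
import evaluate

# Load one SuperGLUE config and compute its metric directly (no Gradio widget).
module = evaluate.load("super_glue", "copa")
print(module.compute(predictions=[0, 1], references=[0, 1]))  # {'accuracy': 1.0}
```
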
record_evaluation.py
ADDED
@@ -0,0 +1,111 @@
```python
"""
Official evaluation script for ReCoRD v1.0.
(Some functions are adopted from the SQuAD evaluation script.)
"""


import argparse
import json
import re
import string
import sys
from collections import Counter


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    correct_ids = []
    for passage in dataset:
        for qa in passage["qas"]:
            total += 1
            if qa["id"] not in predictions:
                message = f'Unanswered question {qa["id"]} will receive score 0.'
                print(message, file=sys.stderr)
                continue

            ground_truths = list(map(lambda x: x["text"], qa["answers"]))
            prediction = predictions[qa["id"]]

            _exact_match = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
            if int(_exact_match) == 1:
                correct_ids.append(qa["id"])
            exact_match += _exact_match

            f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    exact_match = exact_match / total
    f1 = f1 / total

    return {"exact_match": exact_match, "f1": f1}, correct_ids


if __name__ == "__main__":
    expected_version = "1.0"
    parser = argparse.ArgumentParser("Official evaluation script for ReCoRD v1.0.")
    parser.add_argument("data_file", help="The dataset file in JSON format.")
    parser.add_argument("pred_file", help="The model prediction file in JSON format.")
    parser.add_argument("--output_correct_ids", action="store_true", help="Output the correctly answered query IDs.")
    args = parser.parse_args()

    with open(args.data_file) as data_file:
        dataset_json = json.load(data_file)
        if dataset_json["version"] != expected_version:
            print(
                f'Evaluation expects v-{expected_version}, but got dataset with v-{dataset_json["version"]}',
                file=sys.stderr,
            )
        dataset = dataset_json["data"]

    with open(args.pred_file) as pred_file:
        predictions = json.load(pred_file)

    metrics, correct_ids = evaluate(dataset, predictions)

    if args.output_correct_ids:
        print(f"Output {len(correct_ids)} correctly answered question IDs.")
        with open("correct_ids.json", "w") as f:
            json.dump(correct_ids, f)
```
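The script can also be imported rather than run from the command line. A minimal sketch of calling `evaluate()` directly, assuming a ReCoRD-format data file and a prediction file that maps query IDs to answer strings (both file names below are hypothetical):

```python
# Minimal sketch: score ReCoRD predictions without going through the CLI.
# "record_dev.json" and "record_predictions.json" are hypothetical file names.
import json

from record_evaluation import evaluate

with open("record_dev.json") as f:
    dataset = json.load(f)["data"]  # list of passages, each with a "qas" list
with open("record_predictions.json") as f:
    predictions = json.load(f)  # mapping: query id -> predicted answer text

metrics, correct_ids = evaluate(dataset, predictions)
print(metrics)  # {'exact_match': ..., 'f1': ...}
```
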
requirements.txt
ADDED
@@ -0,0 +1,4 @@
```
# TODO: fix github to release
git+https://github.com/huggingface/evaluate.git@b6e6ed7f3e6844b297bff1b43a1b4be0709b9671
datasets~=2.0
sklearn
```
super_glue.py
ADDED
@@ -0,0 +1,237 @@
```python
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The SuperGLUE benchmark metric."""

import datasets
from sklearn.metrics import f1_score, matthews_corrcoef

import evaluate

from .record_evaluation import evaluate as evaluate_record


_CITATION = """\
@article{wang2019superglue,
  title={SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
  author={Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R},
  journal={arXiv preprint arXiv:1905.00537},
  year={2019}
}
"""

_DESCRIPTION = """\
SuperGLUE (https://super.gluebenchmark.com/) is a new benchmark styled after
GLUE with a new set of more difficult language understanding tasks, improved
resources, and a new public leaderboard.
"""

_KWARGS_DESCRIPTION = """
Compute SuperGLUE evaluation metric associated with each SuperGLUE dataset.
Args:
    predictions: list of predictions to score. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'prediction_text': the predicted answer text
        - for 'multirc': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question-answer pair as specified by the dataset
            - 'prediction': the predicted answer label
        - otherwise: list of predicted labels
    references: list of reference labels. Depending on the SuperGLUE subset:
        - for 'record': list of question-answer dictionaries with the following keys:
            - 'idx': index of the question as specified by the dataset
            - 'answers': list of possible answers
        - otherwise: list of reference labels
Returns: depending on the SuperGLUE subset:
    - for 'record':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1': F1 score
    - for 'multirc':
        - 'exact_match': Exact match between answer and gold answer
        - 'f1_m': Per-question macro-F1 score
        - 'f1_a': Average F1 score over all answers
    - for 'axb':
        - 'matthews_correlation': Matthews Correlation
    - for 'cb':
        - 'accuracy': Accuracy
        - 'f1': F1 score
    - for all others:
        - 'accuracy': Accuracy
Examples:

    >>> super_glue_metric = evaluate.load('super_glue', 'copa')  # any of ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}

    >>> super_glue_metric = evaluate.load('super_glue', 'cb')
    >>> predictions = [0, 1]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}

    >>> super_glue_metric = evaluate.load('super_glue', 'record')
    >>> predictions = [{'idx': {'passage': 0, 'query': 0}, 'prediction_text': 'answer'}]
    >>> references = [{'idx': {'passage': 0, 'query': 0}, 'answers': ['answer', 'another_answer']}]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1': 1.0}

    >>> super_glue_metric = evaluate.load('super_glue', 'multirc')
    >>> predictions = [{'idx': {'answer': 0, 'paragraph': 0, 'question': 0}, 'prediction': 0}, {'idx': {'answer': 1, 'paragraph': 2, 'question': 3}, 'prediction': 1}]
    >>> references = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 1.0, 'f1_m': 1.0, 'f1_a': 1.0}

    >>> super_glue_metric = evaluate.load('super_glue', 'axb')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = super_glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""


def simple_accuracy(preds, labels):
    return float((preds == labels).mean())


def acc_and_f1(preds, labels, f1_avg="binary"):
    acc = simple_accuracy(preds, labels)
    f1 = float(f1_score(y_true=labels, y_pred=preds, average=f1_avg))
    return {
        "accuracy": acc,
        "f1": f1,
    }


def evaluate_multirc(ids_preds, labels):
    """
    Computes F1 score and Exact Match for MultiRC predictions.
    """
    question_map = {}
    for id_pred, label in zip(ids_preds, labels):
        question_id = f'{id_pred["idx"]["paragraph"]}-{id_pred["idx"]["question"]}'
        pred = id_pred["prediction"]
        if question_id in question_map:
            question_map[question_id].append((pred, label))
        else:
            question_map[question_id] = [(pred, label)]
    f1s, ems = [], []
    for question, preds_labels in question_map.items():
        question_preds, question_labels = zip(*preds_labels)
        f1 = f1_score(y_true=question_labels, y_pred=question_preds, average="macro")
        f1s.append(f1)
        em = int(sum(p == l for p, l in preds_labels) == len(preds_labels))
        ems.append(em)
    f1_m = float(sum(f1s) / len(f1s))
    em = sum(ems) / len(ems)
    f1_a = float(f1_score(y_true=labels, y_pred=[id_pred["prediction"] for id_pred in ids_preds]))
    return {"exact_match": em, "f1_m": f1_m, "f1_a": f1_a}


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SuperGlue(evaluate.EvaluationModule):
    def _info(self):
        if self.config_name not in [
            "boolq",
            "cb",
            "copa",
            "multirc",
            "record",
            "rte",
            "wic",
            "wsc",
            "wsc.fixed",
            "axb",
            "axg",
        ]:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
            )
        return evaluate.EvaluationModuleInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            codebase_urls=[],
            reference_urls=[],
            format="numpy" if not self.config_name == "record" and not self.config_name == "multirc" else None,
        )

    def _get_feature_types(self):
        if self.config_name == "record":
            return {
                "predictions": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "prediction_text": datasets.Value("string"),
                },
                "references": {
                    "idx": {
                        "passage": datasets.Value("int64"),
                        "query": datasets.Value("int64"),
                    },
                    "answers": datasets.Sequence(datasets.Value("string")),
                },
            }
        elif self.config_name == "multirc":
            return {
                "predictions": {
                    "idx": {
                        "answer": datasets.Value("int64"),
                        "paragraph": datasets.Value("int64"),
                        "question": datasets.Value("int64"),
                    },
                    "prediction": datasets.Value("int64"),
                },
                "references": datasets.Value("int64"),
            }
        else:
            return {
                "predictions": datasets.Value("int64"),
                "references": datasets.Value("int64"),
            }

    def _compute(self, predictions, references):
        if self.config_name == "axb":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        elif self.config_name == "cb":
            return acc_and_f1(predictions, references, f1_avg="macro")
        elif self.config_name == "record":
            dataset = [
                {
                    "qas": [
                        {"id": ref["idx"]["query"], "answers": [{"text": ans} for ans in ref["answers"]]}
                        for ref in references
                    ]
                }
            ]
            predictions = {pred["idx"]["query"]: pred["prediction_text"] for pred in predictions}
            return evaluate_record(dataset, predictions)[0]
        elif self.config_name == "multirc":
            return evaluate_multirc(predictions, references)
        elif self.config_name in ["copa", "rte", "wic", "wsc", "wsc.fixed", "boolq", "axg"]:
            return {"accuracy": simple_accuracy(predictions, references)}
        else:
            raise KeyError(
                "You should supply a configuration name selected in "
                '["boolq", "cb", "copa", "multirc", "record", "rte", "wic", "wsc", "wsc.fixed", "axb", "axg",]'
            )
```