Updated the documentation and added more test cases.
README.md CHANGED
````diff
@@ -25,49 +25,65 @@ summary with the reference overlap summary. It evaluates the semantic overlap su
 computes precision, recall and F1 scores.
 
 ## How to Use
+Sem-F1 takes 2 mandatory arguments:
+- `predictions`: List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
+- `references`: List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
 
 ```python
 from evaluate import load
+
 predictions = [
     ["I go to School.", "You are stupid."],
     ["I love adventure sports."],
 ]
 references = [
     ["I go to School.", "You are stupid."],
-    ["I love
+    ["I love outdoor sports."],
 ]
 metric = load("semf1")
 results = metric.compute(predictions=predictions, references=references)
+for score in results:
+    print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
 ```
 
+Sem-F1 also accepts multiple optional arguments:
+- `model_type (str)`: Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
+  - `pv1` - [paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)
+  - `stsb` - [stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)
+  - `use` - [Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual) (Default)
+- `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
+- `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
+- `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU or multiple processes for computation.
+- `batch_size (int)`: Batch size for encoding. Default: 32.
+- `verbose (bool)`: Flag to indicate verbose output. Default: False.
+
+Refer to the inputs description for more detailed usage:
+```python
+import evaluate
+metric = evaluate.load("semf1")
+metric.inputs_description
+```
 
-[//]: # (### Inputs)
 
 [//]: # (*List all input arguments in the format below*)
 
 [//]: # (- **input_field** *(type): Definition of input, with explanation if necessary. State any default value(s).*)
 
 ### Output Values
-`precision`: The [precision](https://huggingface.co/metrics/precision) for each sentence from the `predictions` + `references` lists, which ranges from 0.0 to 1.0.
-
-`recall`: The [recall](https://huggingface.co/metrics/recall) for each sentence from the `predictions` + `references` lists, which ranges from 0.0 to 1.0.
+List of `Scores` dataclasses, one per sample:
+- `precision: float`: Precision score, which ranges from 0.0 to 1.0.
+- `recall: List[float]`: Recall score corresponding to each reference.
+- `f1: float`: F1 score (between precision and average recall).
 
+## Future Extensions
+Currently, we have only implemented the 3 encoders* that we experimented with in our
+[paper](https://aclanthology.org/2022.emnlp-main.49/). However, it can easily be extended to more models by simply
+extending the `Encoder` base class (refer to the `encoder_models.py` file).
+
+`*` *In our paper, we used the TensorFlow [version](https://www.tensorflow.org/hub/tutorials/semantic_similarity_with_tf_hub_universal_encoder)
+of the USE model; however, in the current implementation, we use the [PyTorch version](https://huggingface.co/sentence-transformers/use-cmlm-multilingual).*
 
 [//]: # (*Give examples, preferably with links to leaderboards or publications, to papers that have reported this metric, along with the values they have reported.*)
````
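Putting the optional arguments above together, here is a minimal usage sketch with multi-reference input (argument names exactly as documented above and exercised by the tests in this commit; the actual scores depend on the chosen encoder):

```python
from evaluate import load

metric = load("semf1")

# Raw documents with several references per example, so
# tokenize_sentences=True and multi_references=True.
predictions = ["I go to School. You are stupid."]
references = [
    ["I go to School. You are stupid.", "I attend school. You are not wise."],
]

results = metric.compute(
    predictions=predictions,
    references=references,
    model_type="stsb",        # 'pv1', 'stsb', or 'use'
    tokenize_sentences=True,
    multi_references=True,
    gpu=False,                # CPU only
    batch_size=32,
    verbose=False,
)
for score in results:
    # One recall value per reference; f1 combines precision and average recall.
    print(score.precision, score.recall, score.f1)
```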
semf1.py CHANGED
```diff
@@ -14,7 +14,6 @@
 # TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
 """Sem-F1 metric"""
 
-from functools import partial
 from typing import List, Optional, Tuple
 
 import datasets
@@ -56,69 +55,93 @@ sentence level and computes precision, recall and F1 scores.
 """
 
 _KWARGS_DESCRIPTION = """
-Sem-F1 compares the system
+Sem-F1 compares the system-generated summaries (predictions) with ground truth reference summaries (references)
+using precision, recall, and F1 score based on sentence embeddings.
 
 Args:
-    predictions
-    references
-        Options:
+    predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
+    references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
+    model_type (str): Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
+        pv1 - paraphrase-distilroberta-base-v1 (Default)
+        stsb - stsb-roberta-large
+        use - Universal Sentence Encoder
+    tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
+    multi_references (bool): Flag to indicate whether multiple references are provided. Default is False.
+    gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
+        bool -
             False - CPU (Default)
-            True - GPU
+            True - GPU (device 0) if gpu is available else CPU
+        int -
+            n - GPU, device index n
+        str -
+            'cuda', 'gpu', 'cpu'
+        List[Union[str, int]] - Multiple GPUs/CPUs, i.e. use multiple processes when computing embeddings
+    batch_size (int): Batch size for encoding. Default is 32.
+    verbose (bool): Flag to indicate verbose output. Default is False.
 
 Returns:
+    List of Scores dataclass with attributes as follows -
+        precision: float - precision score
+        recall: List[float] - List of recall scores corresponding to single/multiple references
+        f1: float - F1 score (between precision and average recall)
+
+Examples of input formats:
+
+    Case 1: multi_references = False, tokenize_sentences = False
+        predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
         references: List[List[str]] - List of references where each reference is a list of sentences.
+        Example:
+            predictions = [["This is a prediction sentence 1.", "This is a prediction sentence 2."]]
+            references = [["This is a reference sentence 1.", "This is a reference sentence 2."]]
+
+    Case 2: multi_references = False, tokenize_sentences = True
+        predictions: List[str] - List of predictions where each prediction is a document.
+        references: List[str] - List of references where each reference is a document.
+        Example:
+            predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
+            references = ["This is a reference sentence 1. This is a reference sentence 2."]
+
+    Case 3: multi_references = True, tokenize_sentences = False
+        predictions: List[List[str]] - List of predictions where each prediction is a list of sentences.
+        references: List[List[List[str]]] - List of references where each example has multi-references (List[r1, r2, ...])
+            and each ri is a List of sentences.
+        Example:
+            predictions = [["Prediction sentence 1.", "Prediction sentence 2."]]
+            references = [
+                [
+                    ["Reference sentence 1.", "Reference sentence 2."],  # Reference 1
+                    ["Alternative reference 1.", "Alternative reference 2."],  # Reference 2
+                ]
+            ]
+
+    Case 4: multi_references = True, tokenize_sentences = True
+        predictions: List[str] - List of predictions where each prediction is a document.
+        references: List[List[str]] - List of references where each example has multi-references (List[r1, r2, ...])
+            where each ri is a document.
+        Example:
+            predictions = ["Prediction sentence 1. Prediction sentence 2."]
+            references = [
+                [
+                    "Reference sentence 1. Reference sentence 2.",  # Reference 1
+                    "Alternative reference 1. Alternative reference 2.",  # Reference 2
+                ]
+            ]
+
 Examples:
 
     >>> import evaluate
     >>> predictions = [
-        ["I go to School.
+        ["I go to School. You are stupid."],
         ["I love adventure sports."],
     ]
     >>> references = [
-        ["I go to School.
-        ["I love
+        ["I go to School. You are stupid."],
+        ["I love outdoor sports."],
     ]
     >>> metric = evaluate.load("semf1")
     >>> results = metric.compute(predictions=predictions, references=references)
+    >>> for score in results:
+    >>>     print(f"Precision: {score.precision}, Recall: {score.recall}, F1: {score.f1}")
 """
```
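Since the `gpu` argument accepts several shapes, a short sketch of the accepted values, following the docstring's enumeration above (whether a given device actually helps depends on the machine):

```python
import evaluate

metric = evaluate.load("semf1")
kwargs = dict(predictions=["A sentence."], references=["A sentence."])

metric.compute(**kwargs, gpu=False)    # CPU (default)
metric.compute(**kwargs, gpu=True)     # GPU device 0 if available, else CPU
metric.compute(**kwargs, gpu=1)        # a specific GPU, by device index
metric.compute(**kwargs, gpu="cuda")   # 'cuda', 'gpu' or 'cpu'
metric.compute(**kwargs, gpu=[0, 1])   # multiple devices: multi-process encoding
```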
```diff
@@ -194,7 +217,12 @@ def _validate_input_format(
     - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
     """
 
+    if len(predictions) != len(references):
+        raise ValueError("Predictions and references must have the same length.")
+
+    def is_list_of_strings_at_depth(lst_obj, depth: int):
+        return is_nested_list_of_type(lst_obj, element_type=str, depth=depth)
+
     if tokenize_sentences and multi_references:
         condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
     elif not tokenize_sentences and multi_references:
@@ -225,7 +253,7 @@ class SemF1(evaluate.Metric):
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
             features=[
-                # Multi References: False, Tokenize_Sentences = False
+                # F0: Multi References: False, Tokenize_Sentences = False
                 datasets.Features(
                     {
                         # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
@@ -234,7 +262,7 @@ class SemF1(evaluate.Metric):
                         "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                     }
                 ),
-                # Multi References: False, Tokenize_Sentences = True
+                # F1: Multi References: False, Tokenize_Sentences = True
                 datasets.Features(
                     {
                         # predictions: List[str] - List of predictions
@@ -243,7 +271,7 @@ class SemF1(evaluate.Metric):
                         "references": datasets.Value("string", id="sequence"),
                     }
                 ),
-                # Multi References: True, Tokenize_Sentences = False
+                # F2: Multi References: True, Tokenize_Sentences = False
                 datasets.Features(
                     {
                         # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
@@ -255,7 +283,7 @@ class SemF1(evaluate.Metric):
                             datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
                     }
                 ),
-                # Multi References: True, Tokenize_Sentences = True
+                # F3: Multi References: True, Tokenize_Sentences = True
                 datasets.Features(
                     {
                         # predictions: List[str] - List of predictions
@@ -319,6 +347,12 @@ class SemF1(evaluate.Metric):
         :return: List of Scores dataclass with precision, recall, and F1 scores.
         """
 
+        # Note: I have to specifically handle this case because the library considers the feature corresponding to
+        # this case (F2) as the feature for the other case (F0) i.e. it can't make any distinction between
+        # List[str] and List[List[str]]
+        if not tokenize_sentences and multi_references:
+            references = [[eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references]
+
         # Validate inputs corresponding to flags
         _validate_input_format(tokenize_sentences, multi_references, predictions, references)
 
@@ -363,10 +397,11 @@ class SemF1(evaluate.Metric):
             # Precision: Concatenate all the sentences in all the references
             concat_refs = np.concatenate(refs, axis=0)
             precision, _ = _compute_cosine_similarity(preds, concat_refs)
+            precision = np.clip(precision, a_min=0.0, a_max=1.0).item()
 
             # Recall: Compute individually for each reference
             recall_scores = [_compute_cosine_similarity(r_embeds, preds) for r_embeds in refs]
-            recall_scores = [r_scores for (r_scores, _) in recall_scores]
+            recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
 
             results.append(Scores(precision, recall_scores))
 
```
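For intuition, here is a minimal numpy sketch of what the last two hunks compute, assuming `_compute_cosine_similarity` returns the mean best-match cosine score as its first element (the tests below verify this against sklearn), and assuming `f1` is the harmonic mean of precision and average recall as the README describes; `mean_max_cosine` is a stand-in, not the repo's helper:

```python
import numpy as np

def mean_max_cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Mean over rows of `a` of the best cosine match among rows of `b`."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return float((a @ b.T).max(axis=-1).mean())

rng = np.random.default_rng(0)
preds = rng.random((3, 8))                       # 3 prediction-sentence embeddings
refs = [rng.random((4, 8)), rng.random((2, 8))]  # 2 references for this example

# Precision: predictions matched against all reference sentences pooled together.
precision = float(np.clip(mean_max_cosine(preds, np.concatenate(refs, axis=0)), 0.0, 1.0))

# Recall: one score per reference, matching its sentences against the predictions.
recalls = [float(np.clip(mean_max_cosine(r, preds), 0.0, 1.0)) for r in refs]

# F1 between precision and average recall (harmonic mean assumed).
mean_recall = float(np.mean(recalls))
f1 = 2 * precision * mean_recall / (precision + mean_recall)
print(precision, recalls, f1)
```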
tests.py CHANGED
```diff
@@ -3,9 +3,12 @@ import unittest
 
 import numpy as np
 import torch
+from numpy.testing import assert_almost_equal
 from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
 
 from encoder_models import SBertEncoder, get_encoder
+from semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
 from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
 
 
```
```diff
@@ -178,5 +181,321 @@ class TestGetEncoder(unittest.TestCase):
         # self.assertEqual(encoder.verbose, verbose)
 
 
+class TestSemF1(unittest.TestCase):
+    def setUp(self):
+        self.semf1_metric = SemF1()
+
+        # Example cases, #Samples = 1
+        self.untokenized_single_reference_predictions = [
+            "This is a prediction sentence 1. This is a prediction sentence 2."]
+        self.untokenized_single_reference_references = [
+            "This is a reference sentence 1. This is a reference sentence 2."]
+
+        self.tokenized_single_reference_predictions = [
+            ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
+        ]
+        self.tokenized_single_reference_references = [
+            ["This is a reference sentence 1.", "This is a reference sentence 2."],
+        ]
+
+        self.untokenized_multi_reference_predictions = [
+            "Prediction sentence 1. Prediction sentence 2."
+        ]
+        self.untokenized_multi_reference_references = [
+            ["Reference sentence 1. Reference sentence 2.", "Alternative reference 1. Alternative reference 2."],
+        ]
+
+        self.tokenized_multi_reference_predictions = [
+            ["Prediction sentence 1.", "Prediction sentence 2."],
+        ]
+        self.tokenized_multi_reference_references = [
+            [
+                ["Reference sentence 1.", "Reference sentence 2."],
+                ["Alternative reference 1.", "Alternative reference 2."]
+            ],
+        ]
+
+    def test_untokenized_single_reference(self):
+        scores = self.semf1_metric.compute(
+            predictions=self.untokenized_single_reference_predictions,
+            references=self.untokenized_single_reference_references,
+            tokenize_sentences=True,
+            multi_references=False,
+            gpu=False,
+            batch_size=32,
+            verbose=False
+        )
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(self.untokenized_single_reference_predictions))
+
+    def test_tokenized_single_reference(self):
+        scores = self.semf1_metric.compute(
+            predictions=self.tokenized_single_reference_predictions,
+            references=self.tokenized_single_reference_references,
+            tokenize_sentences=False,
+            multi_references=False,
+            gpu=False,
+            batch_size=32,
+            verbose=False
+        )
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
+
+        for score in scores:
+            self.assertIsInstance(score, Scores)
+            self.assertTrue(0.0 <= score.precision <= 1.0)
+            self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))
+
+    def test_untokenized_multi_reference(self):
+        scores = self.semf1_metric.compute(
+            predictions=self.untokenized_multi_reference_predictions,
+            references=self.untokenized_multi_reference_references,
+            tokenize_sentences=True,
+            multi_references=True,
+            gpu=False,
+            batch_size=32,
+            verbose=False
+        )
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(self.untokenized_multi_reference_predictions))
+
+    def test_tokenized_multi_reference(self):
+        scores = self.semf1_metric.compute(
+            predictions=self.tokenized_multi_reference_predictions,
+            references=self.tokenized_multi_reference_references,
+            tokenize_sentences=False,
+            multi_references=True,
+            gpu=False,
+            batch_size=32,
+            verbose=False
+        )
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(self.tokenized_multi_reference_predictions))
+
+        for score in scores:
+            self.assertIsInstance(score, Scores)
+            self.assertTrue(0.0 <= score.precision <= 1.0)
+            self.assertTrue(all(0.0 <= recall <= 1.0 for recall in score.recall))
+
+    def test_same_predictions_and_references(self):
+        scores = self.semf1_metric.compute(
+            predictions=self.tokenized_single_reference_predictions,
+            references=self.tokenized_single_reference_predictions,
+            tokenize_sentences=False,
+            multi_references=False,
+            gpu=False,
+            batch_size=32,
+            verbose=False
+        )
+
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
+
+        for score in scores:
+            self.assertIsInstance(score, Scores)
+            self.assertAlmostEqual(score.precision, 1.0, places=6)
+            assert_almost_equal(score.recall, 1, decimal=5, err_msg="Not all values are almost equal to 1")
+
+    def test_exact_output_scores(self):
+        predictions = [
+            ["I go to School.", "You are stupid."],
+            ["I love adventure sports."],
+        ]
+        references = [
+            ["I go to playground.", "You are genius.", "You need to be admired."],
+            ["I love adventure sports."],
+        ]
+        scores = self.semf1_metric.compute(
+            predictions=predictions,
+            references=references,
+            tokenize_sentences=False,
+            multi_references=False,
+            gpu=False,
+            batch_size=32,
+            verbose=False,
+            model_type="use",
+        )
+
+        self.assertIsInstance(scores, list)
+        self.assertEqual(len(scores), len(predictions))
+
+        score = scores[0]
+        self.assertIsInstance(score, Scores)
+        self.assertAlmostEqual(score.precision, 0.73, places=2)
+        self.assertAlmostEqual(score.recall[0], 0.63, places=2)
+
+
+class TestCosineSimilarity(unittest.TestCase):
+
+    def setUp(self):
+        # Sample embeddings for testing
+        self.pred_embeds = np.array([
+            [1, 0, 0],
+            [0, 1, 0],
+            [0, 0, 1]
+        ])
+        self.ref_embeds = np.array([
+            [1, 0, 0],
+            [0, 1, 0],
+            [0, 0, 1]
+        ])
+
+        self.pred_embeds_random = np.random.rand(3, 3)
+        self.ref_embeds_random = np.random.rand(3, 3)
+
+    def test_cosine_similarity_perfect_match(self):
+        precision, recall = _compute_cosine_similarity(self.pred_embeds, self.ref_embeds)
+
+        # Expected values are 1.0 for both precision and recall since embeddings are identical
+        self.assertAlmostEqual(precision, 1.0, places=5)
+        self.assertAlmostEqual(recall, 1.0, places=5)
+
+    def _test_cosine_similarity_base(self, pred_embeds, ref_embeds):
+        precision, recall = _compute_cosine_similarity(pred_embeds, ref_embeds)
+
+        # Calculate expected precision and recall using sklearn's cosine similarity function
+        cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
+        expected_precision = np.mean(np.max(cosine_scores, axis=-1)).item()
+        expected_recall = np.mean(np.max(cosine_scores, axis=0)).item()
+
+        self.assertAlmostEqual(precision, expected_precision, places=5)
+        self.assertAlmostEqual(recall, expected_recall, places=5)
+
+    def test_cosine_similarity_random(self):
+        self._test_cosine_similarity_base(self.pred_embeds_random, self.ref_embeds_random)
+
+    def test_cosine_similarity_different_shapes(self):
+        pred_embeds_diff = np.random.rand(5, 3)
+        ref_embeds_diff = np.random.rand(3, 3)
+        self._test_cosine_similarity_base(pred_embeds_diff, ref_embeds_diff)
+
+
+class TestValidateInputFormat(unittest.TestCase):
+    def setUp(self):
+        # Sample predictions and references for different scenarios where number of samples = 1
+        # Naming convention: "untokenized" when tokenize_sentences = True (raw documents) and vice-versa
+
+        # When tokenize_sentences = True (untokenized input) and multi_references = False
+        self.untokenized_single_reference_predictions = [
+            "This is a prediction sentence 1. This is a prediction sentence 2."
+        ]
+        self.untokenized_single_reference_references = [
+            "This is a reference sentence 1. This is a reference sentence 2."
+        ]
+
+        # When tokenize_sentences = False (tokenized input) and multi_references = False
+        self.tokenized_single_reference_predictions = [
+            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
+        ]
+        self.tokenized_single_reference_references = [
+            ["This is a reference sentence 1.", "This is a reference sentence 2."]
+        ]
+
+        # When tokenize_sentences = True (untokenized input) and multi_references = True
+        self.untokenized_multi_reference_predictions = [
+            "This is a prediction sentence 1. This is a prediction sentence 2."
+        ]
+        self.untokenized_multi_reference_references = [
+            [
+                "This is a reference sentence 1. This is a reference sentence 2.",
+                "Another reference sentence."
+            ]
+        ]
+
+        # When tokenize_sentences = False (tokenized input) and multi_references = True
+        self.tokenized_multi_reference_predictions = [
+            ["This is a prediction sentence 1.", "This is a prediction sentence 2."]
+        ]
+        self.tokenized_multi_reference_references = [
+            [
+                ["This is a reference sentence 1.", "This is a reference sentence 2."],
+                ["Another reference sentence."]
+            ]
+        ]
+
+    def test_tokenized_sentences_true_multi_references_true(self):
+        # Invalid format should raise an error
+        with self.assertRaises(ValueError):
+            _validate_input_format(
+                True,
+                True,
+                self.tokenized_single_reference_predictions,
+                self.tokenized_single_reference_references,
+            )
+
+        # Valid format should pass without error
+        _validate_input_format(
+            True,
+            True,
+            self.untokenized_multi_reference_predictions,
+            self.untokenized_multi_reference_references,
+        )
+
+    def test_tokenized_sentences_false_multi_references_true(self):
+        # Invalid format should raise an error
+        with self.assertRaises(ValueError):
+            _validate_input_format(
+                False,
+                True,
+                self.untokenized_single_reference_predictions,
+                self.untokenized_multi_reference_references,
+            )
+
+        # Valid format should pass without error
+        _validate_input_format(
+            False,
+            True,
+            self.tokenized_multi_reference_predictions,
+            self.tokenized_multi_reference_references,
+        )
+
+    def test_tokenized_sentences_true_multi_references_false(self):
+        # Invalid format should raise an error
+        with self.assertRaises(ValueError):
+            _validate_input_format(
+                True,
+                False,
+                self.tokenized_single_reference_predictions,
+                self.tokenized_single_reference_references,
+            )
+
+        # Valid format should pass without error
+        _validate_input_format(
+            True,
+            False,
+            self.untokenized_single_reference_predictions,
+            self.untokenized_single_reference_references,
+        )
+
+    def test_tokenized_sentences_false_multi_references_false(self):
+        # Invalid format should raise an error
+        with self.assertRaises(ValueError):
+            _validate_input_format(
+                False,
+                False,
+                self.untokenized_single_reference_predictions,
+                self.untokenized_single_reference_references,
+            )
+
+        # Valid format should pass without error
+        _validate_input_format(
+            False,
+            False,
+            self.tokenized_single_reference_predictions,
+            self.tokenized_single_reference_references,
+        )
+
+    def test_mismatched_lengths(self):
+        # Length mismatch should raise an error
+        with self.assertRaises(ValueError):
+            _validate_input_format(
+                True,
+                True,
+                self.untokenized_single_reference_predictions,
+                [self.untokenized_single_reference_predictions[0], self.untokenized_single_reference_predictions[0]],
+            )
+
+
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main(verbosity=2)
+    # unittest.main()
```
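The README's Future Extensions section above suggests adding new encoders by extending the `Encoder` base class in `encoder_models.py`. The base class's exact interface is not shown on this page, so the following is a hypothetical sketch that assumes a single `encode(sentences)` method returning one embedding row per sentence; `MiniLMEncoder` and the model name are illustrative, not part of the repo:

```python
import numpy as np
from sentence_transformers import SentenceTransformer


class MiniLMEncoder:  # hypothetical; would subclass Encoder from encoder_models.py
    def __init__(self, device="cpu", batch_size=32, verbose=False):
        self.model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, sentences) -> np.ndarray:
        # One embedding row per input sentence.
        return self.model.encode(
            sentences,
            batch_size=self.batch_size,
            show_progress_bar=self.verbose,
        )
```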
|