fschlatt commited on
Commit
d93bc17
·
1 Parent(s): ebf03b2

initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. ner_eval.py +668 -40
  3. tests.py +0 -17
  4. tests/test_ner_eval.py +319 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
ner_eval.py CHANGED
@@ -11,24 +11,29 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- """TODO: Add a description here."""
15
 
16
- import evaluate
17
- import datasets
 
18
 
 
 
19
 
20
  # TODO: Add BibTeX citation
21
  _CITATION = """\
22
- @InProceedings{huggingface:module,
23
- title = {A great new module},
24
- authors={huggingface, Inc.},
25
- year={2020}
 
 
26
  }
27
  """
28
 
29
  # TODO: Add description of the module here
30
  _DESCRIPTION = """\
31
- This new module is designed to solve this great ML task and is crafted with a lot of care.
 
32
  """
33
 
34
 
@@ -36,49 +41,166 @@ This new module is designed to solve this great ML task and is crafted with a lo
36
  _KWARGS_DESCRIPTION = """
37
  Calculates how good are predictions given some references, using certain scores
38
  Args:
39
- predictions: list of predictions to score. Each predictions
40
- should be a string with tokens separated by spaces.
41
- references: list of reference for each prediction. Each
42
- reference should be a string with tokens separated by spaces.
43
  Returns:
44
- accuracy: description of the first score,
45
- another_score: description of the second score,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  Examples:
47
- Examples should be written in doctest format, and should illustrate how
48
- to use the function.
49
-
50
- >>> my_new_module = evaluate.load("my_new_module")
51
- >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
52
  >>> print(results)
53
- {'accuracy': 1.0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  """
55
 
56
- # TODO: Define external resources urls if needed
57
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
-
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
- class ner_eval(evaluate.Metric):
62
  """TODO: Short description of my evaluation module."""
63
 
64
  def _info(self):
65
- # TODO: Specifies the evaluate.EvaluationModuleInfo object
66
  return evaluate.MetricInfo(
67
  # This is the description that will appear on the modules page.
68
  module_type="metric",
69
  description=_DESCRIPTION,
70
  citation=_CITATION,
 
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
73
- features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
- }),
77
- # Homepage of the module for documentation
78
- homepage="http://module.homepage",
 
 
 
 
79
  # Additional links to the codebase or references
80
- codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
81
- reference_urls=["http://path.to.reference.url/new_module"]
 
 
 
82
  )
83
 
84
  def _download_and_prepare(self, dl_manager):
@@ -86,10 +208,516 @@ class ner_eval(evaluate.Metric):
86
  # TODO: Download external resources if needed
87
  pass
88
 
89
- def _compute(self, predictions, references):
90
- """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
93
- return {
94
- "accuracy": accuracy,
95
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
 
15
+ from collections import namedtuple
16
+ from copy import deepcopy
17
+ from typing import Sequence, Optional
18
 
19
+ import datasets
20
+ import evaluate
21
 
22
  # TODO: Add BibTeX citation
23
  _CITATION = """\
24
+ @misc{nereval,
25
+ title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1},
26
+ url={https://github.com/davidsbatista/NER-Evaluation},
27
+ note={Software available from https://github.com/davidsbatista/NER-Evaluation},
28
+ author={Batista David},
29
+ year={2018},
30
  }
31
  """
32
 
33
  # TODO: Add description of the module here
34
  _DESCRIPTION = """\
35
+ ner-eval is a Python frame for sequence labeling evaluation. I twas used in SemEval 2013 task 9.1.
36
+ It supports exact match, partial match, spurious and other errors.
37
  """
38
 
39
 
 
41
  _KWARGS_DESCRIPTION = """
42
  Calculates how good are predictions given some references, using certain scores
43
  Args:
44
+ predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
45
+ references: List of List of reference labels (Ground truth (correct) target values)
46
+ tags: List of tags to evaluate. default: None
 
47
  Returns:
48
+ 'scores' dict. Summary of the scores for overall and each tag.
49
+ {
50
+ "overall": {
51
+ "strict_precision": 0.0,
52
+ "strict_recall": 0.0,
53
+ "strict_f1": 0,
54
+ "ent_type_precision": 0.0,
55
+ "ent_type_recall": 0.0,
56
+ "ent_type_f1": 0,
57
+ "partial_precision": 0.0,
58
+ "partial_recall": 0.0,
59
+ "partial_f1": 0,
60
+ "exact_precision": 0.0,
61
+ "exact_recall": 0.0,
62
+ "exact_f1": 0,
63
+ },
64
+ "ORG": {
65
+ "strict_precision": 0.0,
66
+ "strict_recall": 0.0,
67
+ "strict_f1": 0,
68
+ "ent_type_precision": 0.0,
69
+ "ent_type_recall": 0.0,
70
+ "ent_type_f1": 0,
71
+ "partial_precision": 0.0,
72
+ "partial_recall": 0.0,
73
+ "partial_f1": 0,
74
+ "exact_precision": 0.0,
75
+ "exact_recall": 0.0,
76
+ "exact_f1": 0,
77
+ },
78
+ "PER": {
79
+ "strict_precision": 0.0,
80
+ "strict_recall": 0.0,
81
+ "strict_f1": 0,
82
+ "ent_type_precision": 0.0,
83
+ "ent_type_recall": 0.0,
84
+ "ent_type_f1": 0,
85
+ "partial_precision": 0.0,
86
+ "partial_recall": 0.0,
87
+ "partial_f1": 0,
88
+ "exact_precision": 0.0,
89
+ "exact_recall": 0.0,
90
+ "exact_f1": 0,
91
+ },
92
+ "LOC": {
93
+ "strict_precision": 0.0,
94
+ "strict_recall": 0.0,
95
+ "strict_f1": 0,
96
+ "ent_type_precision": 0.0,
97
+ "ent_type_recall": 0.0,
98
+ "ent_type_f1": 0,
99
+ "partial_precision": 0.0,
100
+ "partial_recall": 0.0,
101
+ "partial_f1": 0,
102
+ "exact_precision": 0.0,
103
+ "exact_recall": 0.0,
104
+ "exact_f1": 0,
105
+ },
106
+ }
107
  Examples:
108
+ >>> my_new_module = evaluate.load("fschlatt/ner_eval")
109
+ >>> results = my_new_module.compute(
110
+ ... references=[["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]],
111
+ ... predictions=[["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]]
112
+ ... )
113
  >>> print(results)
114
+ {
115
+ "overall": {
116
+ "strict_precision": 0.0,
117
+ "strict_recall": 0.0,
118
+ "strict_f1": 0,
119
+ "ent_type_precision": 2 / 3,
120
+ "ent_type_recall": 2 / 3,
121
+ "ent_type_f1": 2 / 3,
122
+ "partial_precision": 1 / 3,
123
+ "partial_recall": 1 / 3,
124
+ "partial_f1": 1 / 3,
125
+ "exact_precision": 0.0,
126
+ "exact_recall": 0.0,
127
+ "exact_f1": 0,
128
+ },
129
+ "ORG": {
130
+ "strict_precision": 0.0,
131
+ "strict_recall": 0.0,
132
+ "strict_f1": 0,
133
+ "ent_type_precision": 0.0,
134
+ "ent_type_recall": 0.0,
135
+ "ent_type_f1": 0,
136
+ "partial_precision": 0.0,
137
+ "partial_recall": 0.0,
138
+ "partial_f1": 0,
139
+ "exact_precision": 0.0,
140
+ "exact_recall": 0.0,
141
+ "exact_f1": 0,
142
+ },
143
+ "PER": {
144
+ "strict_precision": 0.0,
145
+ "strict_recall": 0.0,
146
+ "strict_f1": 0,
147
+ "ent_type_precision": 0.5,
148
+ "ent_type_recall": 1.0,
149
+ "ent_type_f1": 2 / 3,
150
+ "partial_precision": 0.25,
151
+ "partial_recall": 0.5,
152
+ "partial_f1": 1 / 3,
153
+ "exact_precision": 0.0,
154
+ "exact_recall": 0.0,
155
+ "exact_f1": 0,
156
+ },
157
+ "LOC": {
158
+ "strict_precision": 0.0,
159
+ "strict_recall": 0.0,
160
+ "strict_f1": 0,
161
+ "ent_type_precision": 0.5,
162
+ "ent_type_recall": 1.0,
163
+ "ent_type_f1": 2 / 3,
164
+ "partial_precision": 0.25,
165
+ "partial_recall": 0.5,
166
+ "partial_f1": 1 / 3,
167
+ "exact_precision": 0.0,
168
+ "exact_recall": 0.0,
169
+ "exact_f1": 0,
170
+ }
171
+ }
172
  """
173
 
 
 
 
174
 
175
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
176
+ class NEREval(evaluate.Metric):
177
  """TODO: Short description of my evaluation module."""
178
 
179
  def _info(self):
 
180
  return evaluate.MetricInfo(
181
  # This is the description that will appear on the modules page.
182
  module_type="metric",
183
  description=_DESCRIPTION,
184
  citation=_CITATION,
185
+ homepage="https://github.com/davidsbatista/NER-Evaluation",
186
  inputs_description=_KWARGS_DESCRIPTION,
187
  # This defines the format of each prediction and reference
188
+ features=datasets.Features(
189
+ {
190
+ "predictions": datasets.Sequence(
191
+ datasets.Value("string", id="label"), id="sequence"
192
+ ),
193
+ "references": datasets.Sequence(
194
+ datasets.Value("string", id="label"), id="sequence"
195
+ ),
196
+ }
197
+ ),
198
  # Additional links to the codebase or references
199
+ codebase_urls=["https://github.com/davidsbatista/NER-Evaluation"],
200
+ reference_urls=[
201
+ "https://github.com/davidsbatista/NER-Evaluation",
202
+ "https://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/",
203
+ ],
204
  )
205
 
206
  def _download_and_prepare(self, dl_manager):
 
208
  # TODO: Download external resources if needed
209
  pass
210
 
211
+ def _compute(
212
+ self,
213
+ predictions: Sequence[Sequence[str]],
214
+ references: Sequence[Sequence[str]],
215
+ tags: Optional[Sequence[str]] = None,
216
+ modes: Optional[Sequence[str]] = None,
217
+ ):
218
+ if tags is None:
219
+ tags = list(parse_tags(predictions).union(parse_tags(references)))
220
+
221
+ evaluator = Evaluator(predictions, references, tags)
222
+ results, agg_results = evaluator.evaluate()
223
+
224
+ out = {"overall": parse_results(results, modes)}
225
+ for tag, tag_result in agg_results.items():
226
+ out = {**out, tag: parse_results(tag_result, modes)}
227
+
228
+ return out
229
+
230
+
231
+ def parse_results(results, modes: Optional[Sequence[str]] = None):
232
+ if modes is None:
233
+ modes = ["strict", "ent_type", "partial", "exact"]
234
+
235
+ out = {}
236
+ for mode in modes:
237
+ out[f"{mode}_precision"] = results[mode]["precision"]
238
+ out[f"{mode}_recall"] = results[mode]["recall"]
239
+ out[f"{mode}_f1"] = results[mode]["f1"]
240
+ return out
241
+
242
+
243
+ def parse_tags(tokens: Sequence[Sequence[str]]):
244
+ tags = set()
245
+ for seq in tokens:
246
+ for t in seq:
247
+ tags.add(t.split("-")[-1])
248
+ tags.discard("O")
249
+ return tags
250
+
251
+
252
+ Entity = namedtuple("Entity", "e_type start_offset end_offset")
253
+
254
+
255
+ class Evaluator:
256
+ def __init__(self, true, pred, tags):
257
+ """ """
258
+
259
+ if len(true) != len(pred):
260
+ raise ValueError("Number of predicted documents does not equal true")
261
+
262
+ self.true = true
263
+ self.pred = pred
264
+ self.tags = tags
265
+
266
+ # Setup dict into which metrics will be stored.
267
+
268
+ self.metrics_results = {
269
+ "correct": 0,
270
+ "incorrect": 0,
271
+ "partial": 0,
272
+ "missed": 0,
273
+ "spurious": 0,
274
+ "possible": 0,
275
+ "actual": 0,
276
+ "precision": 0,
277
+ "recall": 0,
278
+ }
279
+
280
+ # Copy results dict to cover the four schemes.
281
+
282
+ self.results = {
283
+ "strict": deepcopy(self.metrics_results),
284
+ "ent_type": deepcopy(self.metrics_results),
285
+ "partial": deepcopy(self.metrics_results),
286
+ "exact": deepcopy(self.metrics_results),
287
+ }
288
+
289
+ # Create an accumulator to store results
290
+
291
+ self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}
292
+
293
+ def evaluate(self):
294
+ for true_ents, pred_ents in zip(self.true, self.pred):
295
+ # Check that the length of the true and predicted examples are the
296
+ # same. This must be checked here, because another error may not
297
+ # be thrown if the lengths do not match.
298
+
299
+ if len(true_ents) != len(pred_ents):
300
+ raise ValueError("Prediction length does not match true example length")
301
+
302
+ # Compute results for one message
303
+
304
+ tmp_results, tmp_agg_results = compute_metrics(
305
+ collect_named_entities(true_ents),
306
+ collect_named_entities(pred_ents),
307
+ self.tags,
308
+ )
309
+
310
+ # Cycle through each result and accumulate
311
+
312
+ # TODO: Combine these loops below:
313
+
314
+ for eval_schema in self.results:
315
+ for metric in self.results[eval_schema]:
316
+ self.results[eval_schema][metric] += tmp_results[eval_schema][
317
+ metric
318
+ ]
319
+
320
+ # Calculate global precision and recall
321
+
322
+ self.results = compute_precision_recall_f1_wrapper(self.results)
323
+
324
+ # Aggregate results by entity type
325
+
326
+ for e_type in self.tags:
327
+ for eval_schema in tmp_agg_results[e_type]:
328
+ for metric in tmp_agg_results[e_type][eval_schema]:
329
+ self.evaluation_agg_entities_type[e_type][eval_schema][
330
+ metric
331
+ ] += tmp_agg_results[e_type][eval_schema][metric]
332
+
333
+ # Calculate precision recall at the individual entity level
334
+
335
+ self.evaluation_agg_entities_type[
336
+ e_type
337
+ ] = compute_precision_recall_f1_wrapper(
338
+ self.evaluation_agg_entities_type[e_type]
339
+ )
340
+
341
+ return self.results, self.evaluation_agg_entities_type
342
+
343
+
344
+ def collect_named_entities(tokens):
345
+ """
346
+ Creates a list of Entity named-tuples, storing the entity type and the start and end
347
+ offsets of the entity.
348
+
349
+ :param tokens: a list of tags
350
+ :return: a list of Entity named-tuples
351
+ """
352
+
353
+ named_entities = []
354
+ start_offset = None
355
+ end_offset = None
356
+ ent_type = None
357
+
358
+ for offset, token_tag in enumerate(tokens):
359
+ if token_tag == "O":
360
+ if ent_type is not None and start_offset is not None:
361
+ end_offset = offset - 1
362
+ named_entities.append(Entity(ent_type, start_offset, end_offset))
363
+ start_offset = None
364
+ end_offset = None
365
+ ent_type = None
366
+
367
+ elif ent_type is None:
368
+ ent_type = token_tag[2:]
369
+ start_offset = offset
370
+
371
+ elif ent_type != token_tag[2:] or (
372
+ ent_type == token_tag[2:] and token_tag[:1] == "B"
373
+ ):
374
+ end_offset = offset - 1
375
+ named_entities.append(Entity(ent_type, start_offset, end_offset))
376
+
377
+ # start of a new entity
378
+ ent_type = token_tag[2:]
379
+ start_offset = offset
380
+ end_offset = None
381
+
382
+ # catches an entity that goes up until the last token
383
+
384
+ if ent_type is not None and start_offset is not None and end_offset is None:
385
+ named_entities.append(Entity(ent_type, start_offset, len(tokens) - 1))
386
+
387
+ return named_entities
388
+
389
+
390
+ def compute_metrics(true_named_entities, pred_named_entities, tags):
391
+ eval_metrics = {
392
+ "correct": 0,
393
+ "incorrect": 0,
394
+ "partial": 0,
395
+ "missed": 0,
396
+ "spurious": 0,
397
+ "precision": 0,
398
+ "recall": 0,
399
+ }
400
+
401
+ # overall results
402
+
403
+ evaluation = {
404
+ "strict": deepcopy(eval_metrics),
405
+ "ent_type": deepcopy(eval_metrics),
406
+ "partial": deepcopy(eval_metrics),
407
+ "exact": deepcopy(eval_metrics),
408
+ }
409
+
410
+ # results by entity type
411
+
412
+ evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}
413
+
414
+ # keep track of entities that overlapped
415
+
416
+ true_which_overlapped_with_pred = []
417
+
418
+ # Subset into only the tags that we are interested in.
419
+ # NOTE: we remove the tags we don't want from both the predicted and the
420
+ # true entities. This covers the two cases where mismatches can occur:
421
+ #
422
+ # 1) Where the model predicts a tag that is not present in the true data
423
+ # 2) Where there is a tag in the true data that the model is not capable of
424
+ # predicting.
425
+
426
+ true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags]
427
+ pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags]
428
+
429
+ # go through each predicted named-entity
430
+
431
+ for pred in pred_named_entities:
432
+ found_overlap = False
433
+
434
+ # Check each of the potential scenarios in turn. See
435
+ # http://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/
436
+ # for scenario explanation.
437
+
438
+ # Scenario I: Exact match between true and pred
439
+
440
+ if pred in true_named_entities:
441
+ true_which_overlapped_with_pred.append(pred)
442
+ evaluation["strict"]["correct"] += 1
443
+ evaluation["ent_type"]["correct"] += 1
444
+ evaluation["exact"]["correct"] += 1
445
+ evaluation["partial"]["correct"] += 1
446
+
447
+ # for the agg. by e_type results
448
+ evaluation_agg_entities_type[pred.e_type]["strict"]["correct"] += 1
449
+ evaluation_agg_entities_type[pred.e_type]["ent_type"]["correct"] += 1
450
+ evaluation_agg_entities_type[pred.e_type]["exact"]["correct"] += 1
451
+ evaluation_agg_entities_type[pred.e_type]["partial"]["correct"] += 1
452
+
453
+ else:
454
+ # check for overlaps with any of the true entities
455
+
456
+ for true in true_named_entities:
457
+ pred_range = range(pred.start_offset, pred.end_offset)
458
+ true_range = range(true.start_offset, true.end_offset)
459
+
460
+ # Scenario IV: Offsets match, but entity type is wrong
461
+
462
+ if (
463
+ true.start_offset == pred.start_offset
464
+ and pred.end_offset == true.end_offset
465
+ and true.e_type != pred.e_type
466
+ ):
467
+ # overall results
468
+ evaluation["strict"]["incorrect"] += 1
469
+ evaluation["ent_type"]["incorrect"] += 1
470
+ evaluation["partial"]["correct"] += 1
471
+ evaluation["exact"]["correct"] += 1
472
+
473
+ # aggregated by entity type results
474
+ evaluation_agg_entities_type[true.e_type]["strict"][
475
+ "incorrect"
476
+ ] += 1
477
+ evaluation_agg_entities_type[true.e_type]["ent_type"][
478
+ "incorrect"
479
+ ] += 1
480
+ evaluation_agg_entities_type[true.e_type]["partial"]["correct"] += 1
481
+ evaluation_agg_entities_type[true.e_type]["exact"]["correct"] += 1
482
+
483
+ true_which_overlapped_with_pred.append(true)
484
+ found_overlap = True
485
+
486
+ break
487
+
488
+ # check for an overlap i.e. not exact boundary match, with true entities
489
+
490
+ elif find_overlap(true_range, pred_range):
491
+ true_which_overlapped_with_pred.append(true)
492
+
493
+ # Scenario V: There is an overlap (but offsets do not match
494
+ # exactly), and the entity type is the same.
495
+ # 2.1 overlaps with the same entity type
496
+
497
+ if pred.e_type == true.e_type:
498
+ # overall results
499
+ evaluation["strict"]["incorrect"] += 1
500
+ evaluation["ent_type"]["correct"] += 1
501
+ evaluation["partial"]["partial"] += 1
502
+ evaluation["exact"]["incorrect"] += 1
503
+
504
+ # aggregated by entity type results
505
+ evaluation_agg_entities_type[true.e_type]["strict"][
506
+ "incorrect"
507
+ ] += 1
508
+ evaluation_agg_entities_type[true.e_type]["ent_type"][
509
+ "correct"
510
+ ] += 1
511
+ evaluation_agg_entities_type[true.e_type]["partial"][
512
+ "partial"
513
+ ] += 1
514
+ evaluation_agg_entities_type[true.e_type]["exact"][
515
+ "incorrect"
516
+ ] += 1
517
+
518
+ found_overlap = True
519
+
520
+ break
521
+
522
+ # Scenario VI: Entities overlap, but the entity type is
523
+ # different.
524
+
525
+ else:
526
+ # overall results
527
+ evaluation["strict"]["incorrect"] += 1
528
+ evaluation["ent_type"]["incorrect"] += 1
529
+ evaluation["partial"]["partial"] += 1
530
+ evaluation["exact"]["incorrect"] += 1
531
+
532
+ # aggregated by entity type results
533
+ # Results against the true entity
534
+
535
+ evaluation_agg_entities_type[true.e_type]["strict"][
536
+ "incorrect"
537
+ ] += 1
538
+ evaluation_agg_entities_type[true.e_type]["partial"][
539
+ "partial"
540
+ ] += 1
541
+ evaluation_agg_entities_type[true.e_type]["ent_type"][
542
+ "incorrect"
543
+ ] += 1
544
+ evaluation_agg_entities_type[true.e_type]["exact"][
545
+ "incorrect"
546
+ ] += 1
547
+
548
+ # Results against the predicted entity
549
+
550
+ # evaluation_agg_entities_type[pred.e_type]['strict']['spurious'] += 1
551
+
552
+ found_overlap = True
553
+
554
+ break
555
+
556
+ # Scenario II: Entities are spurious (i.e., over-generated).
557
+
558
+ if not found_overlap:
559
+ # Overall results
560
+
561
+ evaluation["strict"]["spurious"] += 1
562
+ evaluation["ent_type"]["spurious"] += 1
563
+ evaluation["partial"]["spurious"] += 1
564
+ evaluation["exact"]["spurious"] += 1
565
+
566
+ # Aggregated by entity type results
567
+
568
+ # NOTE: when pred.e_type is not found in tags
569
+ # or when it simply does not appear in the test set, then it is
570
+ # spurious, but it is not clear where to assign it at the tag
571
+ # level. In this case, it is applied to all target_tags
572
+ # found in this example. This will mean that the sum of the
573
+ # evaluation_agg_entities will not equal evaluation.
574
+
575
+ for true in tags:
576
+ evaluation_agg_entities_type[true]["strict"]["spurious"] += 1
577
+ evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1
578
+ evaluation_agg_entities_type[true]["partial"]["spurious"] += 1
579
+ evaluation_agg_entities_type[true]["exact"]["spurious"] += 1
580
+
581
+ # Scenario III: Entity was missed entirely.
582
+
583
+ for true in true_named_entities:
584
+ if true in true_which_overlapped_with_pred:
585
+ continue
586
+ else:
587
+ # overall results
588
+ evaluation["strict"]["missed"] += 1
589
+ evaluation["ent_type"]["missed"] += 1
590
+ evaluation["partial"]["missed"] += 1
591
+ evaluation["exact"]["missed"] += 1
592
+
593
+ # for the agg. by e_type
594
+ evaluation_agg_entities_type[true.e_type]["strict"]["missed"] += 1
595
+ evaluation_agg_entities_type[true.e_type]["ent_type"]["missed"] += 1
596
+ evaluation_agg_entities_type[true.e_type]["partial"]["missed"] += 1
597
+ evaluation_agg_entities_type[true.e_type]["exact"]["missed"] += 1
598
+
599
+ # Compute 'possible', 'actual' according to SemEval-2013 Task 9.1 on the
600
+ # overall results, and use these to calculate precision and recall.
601
+
602
+ for eval_type in evaluation:
603
+ evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])
604
+
605
+ # Compute 'possible', 'actual', and precision and recall on entity level
606
+ # results. Start by cycling through the accumulated results.
607
+
608
+ for entity_type, entity_level in evaluation_agg_entities_type.items():
609
+ # Cycle through the evaluation types for each dict containing entity
610
+ # level results.
611
+
612
+ for eval_type in entity_level:
613
+ evaluation_agg_entities_type[entity_type][
614
+ eval_type
615
+ ] = compute_actual_possible(entity_level[eval_type])
616
+
617
+ return evaluation, evaluation_agg_entities_type
618
+
619
+
620
+ def find_overlap(true_range, pred_range):
621
+ """Find the overlap between two ranges
622
+
623
+ Find the overlap between two ranges. Return the overlapping values if
624
+ present, else return an empty set().
625
+
626
+ Examples:
627
+
628
+ >>> find_overlap((1, 2), (2, 3))
629
+ 2
630
+ >>> find_overlap((1, 2), (3, 4))
631
+ set()
632
+ """
633
+
634
+ true_set = set(true_range)
635
+ pred_set = set(pred_range)
636
+
637
+ overlaps = true_set.intersection(pred_set)
638
+
639
+ return overlaps
640
+
641
+
642
+ def compute_actual_possible(results):
643
+ """
644
+ Takes a result dict that has been output by compute metrics.
645
+ Returns the results dict with actual, possible populated.
646
+
647
+ When the results dicts is from partial or ent_type metrics, then
648
+ partial_or_type=True to ensure the right calculation is used for
649
+ calculating precision and recall.
650
+ """
651
+
652
+ correct = results["correct"]
653
+ incorrect = results["incorrect"]
654
+ partial = results["partial"]
655
+ missed = results["missed"]
656
+ spurious = results["spurious"]
657
+
658
+ # Possible: number annotations in the gold-standard which contribute to the
659
+ # final score
660
+
661
+ possible = correct + incorrect + partial + missed
662
+
663
+ # Actual: number of annotations produced by the NER system
664
+
665
+ actual = correct + incorrect + partial + spurious
666
+
667
+ results["actual"] = actual
668
+ results["possible"] = possible
669
+
670
+ return results
671
+
672
+
673
+ def compute_precision_recall_f1(results, partial_or_type=False):
674
+ """
675
+ Takes a result dict that has been output by compute metrics.
676
+ Returns the results dict with precison and recall populated.
677
+
678
+ When the results dicts is from partial or ent_type metrics, then
679
+ partial_or_type=True to ensure the right calculation is used for
680
+ calculating precision and recall.
681
+ """
682
+
683
+ actual = results["actual"]
684
+ possible = results["possible"]
685
+ partial = results["partial"]
686
+ correct = results["correct"]
687
+
688
+ if partial_or_type:
689
+ precision = (correct + 0.5 * partial) / actual if actual > 0 else 0
690
+ recall = (correct + 0.5 * partial) / possible if possible > 0 else 0
691
+
692
+ else:
693
+ precision = correct / actual if actual > 0 else 0
694
+ recall = correct / possible if possible > 0 else 0
695
+
696
+ results["precision"] = precision
697
+ results["recall"] = recall
698
+ results["f1"] = (
699
+ precision * recall * 2 / (precision + recall) if precision + recall > 0 else 0
700
+ )
701
+
702
+ return results
703
+
704
+
705
+ def compute_precision_recall_f1_wrapper(results):
706
+ """
707
+ Wraps the compute_precision_recall_f1 function and runs on a dict of results
708
+ """
709
+
710
+ results_a = {
711
+ key: compute_precision_recall_f1(value, True)
712
+ for key, value in results.items()
713
+ if key in ["partial", "ent_type"]
714
+ }
715
+ results_b = {
716
+ key: compute_precision_recall_f1(value)
717
+ for key, value in results.items()
718
+ if key in ["strict", "exact"]
719
+ }
720
+
721
+ results = {**results_a, **results_b}
722
+
723
+ return results
tests.py DELETED
@@ -1,17 +0,0 @@
1
- test_cases = [
2
- {
3
- "predictions": [0, 0],
4
- "references": [1, 1],
5
- "result": {"metric_score": 0}
6
- },
7
- {
8
- "predictions": [1, 1],
9
- "references": [1, 1],
10
- "result": {"metric_score": 1}
11
- },
12
- {
13
- "predictions": [1, 0],
14
- "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
- }
17
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_ner_eval.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import pytest
3
+
4
+ ner_eval = evaluate.load("ner_eval.py")
5
+
6
+ test_cases = [
7
+ {
8
+ "predictions": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
9
+ "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
10
+ "results": {
11
+ "overall": {
12
+ "strict_precision": 1.0,
13
+ "strict_recall": 1.0,
14
+ "strict_f1": 1.0,
15
+ "ent_type_precision": 1.0,
16
+ "ent_type_recall": 1.0,
17
+ "ent_type_f1": 1.0,
18
+ "partial_precision": 1.0,
19
+ "partial_recall": 1.0,
20
+ "partial_f1": 1.0,
21
+ "exact_precision": 1.0,
22
+ "exact_recall": 1.0,
23
+ "exact_f1": 1.0,
24
+ },
25
+ "LOC": {
26
+ "strict_precision": 1.0,
27
+ "strict_recall": 1.0,
28
+ "strict_f1": 1.0,
29
+ "ent_type_precision": 1.0,
30
+ "ent_type_recall": 1.0,
31
+ "ent_type_f1": 1.0,
32
+ "partial_precision": 1.0,
33
+ "partial_recall": 1.0,
34
+ "partial_f1": 1.0,
35
+ "exact_precision": 1.0,
36
+ "exact_recall": 1.0,
37
+ "exact_f1": 1.0,
38
+ },
39
+ "PER": {
40
+ "strict_precision": 1.0,
41
+ "strict_recall": 1.0,
42
+ "strict_f1": 1.0,
43
+ "ent_type_precision": 1.0,
44
+ "ent_type_recall": 1.0,
45
+ "ent_type_f1": 1.0,
46
+ "partial_precision": 1.0,
47
+ "partial_recall": 1.0,
48
+ "partial_f1": 1.0,
49
+ "exact_precision": 1.0,
50
+ "exact_recall": 1.0,
51
+ "exact_f1": 1.0,
52
+ },
53
+ "ORG": {
54
+ "strict_precision": 1.0,
55
+ "strict_recall": 1.0,
56
+ "strict_f1": 1.0,
57
+ "ent_type_precision": 1.0,
58
+ "ent_type_recall": 1.0,
59
+ "ent_type_f1": 1.0,
60
+ "partial_precision": 1.0,
61
+ "partial_recall": 1.0,
62
+ "partial_f1": 1.0,
63
+ "exact_precision": 1.0,
64
+ "exact_recall": 1.0,
65
+ "exact_f1": 1.0,
66
+ },
67
+ },
68
+ },
69
+ {
70
+ "predictions": [
71
+ "B-LOC",
72
+ "I-LOC",
73
+ "O",
74
+ "B-PER",
75
+ "I-PER",
76
+ "I-PER",
77
+ "I-PER",
78
+ "O",
79
+ "B-LOC",
80
+ "O",
81
+ ],
82
+ "references": [
83
+ "B-LOC",
84
+ "I-LOC",
85
+ "O",
86
+ "B-PER",
87
+ "I-PER",
88
+ "I-PER",
89
+ "I-PER",
90
+ "O",
91
+ "B-LOC",
92
+ "O",
93
+ ],
94
+ "results": {
95
+ "overall": {
96
+ "strict_precision": 1.0,
97
+ "strict_recall": 1.0,
98
+ "strict_f1": 1.0,
99
+ "ent_type_precision": 1.0,
100
+ "ent_type_recall": 1.0,
101
+ "ent_type_f1": 1.0,
102
+ "partial_precision": 1.0,
103
+ "partial_recall": 1.0,
104
+ "partial_f1": 1.0,
105
+ "exact_precision": 1.0,
106
+ "exact_recall": 1.0,
107
+ "exact_f1": 1.0,
108
+ },
109
+ "LOC": {
110
+ "strict_precision": 1.0,
111
+ "strict_recall": 1.0,
112
+ "strict_f1": 1.0,
113
+ "ent_type_precision": 1.0,
114
+ "ent_type_recall": 1.0,
115
+ "ent_type_f1": 1.0,
116
+ "partial_precision": 1.0,
117
+ "partial_recall": 1.0,
118
+ "partial_f1": 1.0,
119
+ "exact_precision": 1.0,
120
+ "exact_recall": 1.0,
121
+ "exact_f1": 1.0,
122
+ },
123
+ "PER": {
124
+ "strict_precision": 1.0,
125
+ "strict_recall": 1.0,
126
+ "strict_f1": 1.0,
127
+ "ent_type_precision": 1.0,
128
+ "ent_type_recall": 1.0,
129
+ "ent_type_f1": 1.0,
130
+ "partial_precision": 1.0,
131
+ "partial_recall": 1.0,
132
+ "partial_f1": 1.0,
133
+ "exact_precision": 1.0,
134
+ "exact_recall": 1.0,
135
+ "exact_f1": 1.0,
136
+ },
137
+ },
138
+ },
139
+ {
140
+ "predictions": ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "O", "B-ORG"],
141
+ "references": ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-ORG"],
142
+ },
143
+ {
144
+ "predictions": ["B-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG", "I-ORG"],
145
+ "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG"],
146
+ "results": {
147
+ "overall": {
148
+ "strict_precision": 0.0,
149
+ "strict_recall": 0.0,
150
+ "strict_f1": 0,
151
+ "ent_type_precision": 0.0,
152
+ "ent_type_recall": 0.0,
153
+ "ent_type_f1": 0,
154
+ "partial_precision": 0.0,
155
+ "partial_recall": 0.0,
156
+ "partial_f1": 0,
157
+ "exact_precision": 0.0,
158
+ "exact_recall": 0.0,
159
+ "exact_f1": 0,
160
+ },
161
+ "ORG": {
162
+ "strict_precision": 0.0,
163
+ "strict_recall": 0.0,
164
+ "strict_f1": 0,
165
+ "ent_type_precision": 0.0,
166
+ "ent_type_recall": 0.0,
167
+ "ent_type_f1": 0,
168
+ "partial_precision": 0.0,
169
+ "partial_recall": 0.0,
170
+ "partial_f1": 0,
171
+ "exact_precision": 0.0,
172
+ "exact_recall": 0.0,
173
+ "exact_f1": 0,
174
+ },
175
+ "PER": {
176
+ "strict_precision": 0.0,
177
+ "strict_recall": 0.0,
178
+ "strict_f1": 0,
179
+ "ent_type_precision": 0.0,
180
+ "ent_type_recall": 0.0,
181
+ "ent_type_f1": 0,
182
+ "partial_precision": 0.0,
183
+ "partial_recall": 0.0,
184
+ "partial_f1": 0,
185
+ "exact_precision": 0.0,
186
+ "exact_recall": 0.0,
187
+ "exact_f1": 0,
188
+ },
189
+ "LOC": {
190
+ "strict_precision": 0.0,
191
+ "strict_recall": 0.0,
192
+ "strict_f1": 0,
193
+ "ent_type_precision": 0.0,
194
+ "ent_type_recall": 0.0,
195
+ "ent_type_f1": 0,
196
+ "partial_precision": 0.0,
197
+ "partial_recall": 0.0,
198
+ "partial_f1": 0,
199
+ "exact_precision": 0.0,
200
+ "exact_recall": 0.0,
201
+ "exact_f1": 0,
202
+ },
203
+ },
204
+ },
205
+ {
206
+ "predictions": [
207
+ "B-LOC",
208
+ "I-LOC",
209
+ "I-LOC",
210
+ "B-ORG",
211
+ "I-ORG",
212
+ "O",
213
+ "B-PER",
214
+ "I-PER",
215
+ "I-PER",
216
+ "O",
217
+ ],
218
+ "references": [
219
+ "B-LOC",
220
+ "I-LOC",
221
+ "O",
222
+ "O",
223
+ "B-ORG",
224
+ "I-ORG",
225
+ "O",
226
+ "B-PER",
227
+ "I-PER",
228
+ "O",
229
+ ],
230
+ "results": {
231
+ "overall": {
232
+ "strict_precision": 0.0,
233
+ "strict_recall": 0.0,
234
+ "strict_f1": 0,
235
+ "ent_type_precision": 2 / 3,
236
+ "ent_type_recall": 2 / 3,
237
+ "ent_type_f1": 2 / 3,
238
+ "partial_precision": 1 / 3,
239
+ "partial_recall": 1 / 3,
240
+ "partial_f1": 1 / 3,
241
+ "exact_precision": 0.0,
242
+ "exact_recall": 0.0,
243
+ "exact_f1": 0,
244
+ },
245
+ "ORG": {
246
+ "strict_precision": 0.0,
247
+ "strict_recall": 0.0,
248
+ "strict_f1": 0,
249
+ "ent_type_precision": 0.0,
250
+ "ent_type_recall": 0.0,
251
+ "ent_type_f1": 0,
252
+ "partial_precision": 0.0,
253
+ "partial_recall": 0.0,
254
+ "partial_f1": 0,
255
+ "exact_precision": 0.0,
256
+ "exact_recall": 0.0,
257
+ "exact_f1": 0,
258
+ },
259
+ "PER": {
260
+ "strict_precision": 0.0,
261
+ "strict_recall": 0.0,
262
+ "strict_f1": 0,
263
+ "ent_type_precision": 0.5,
264
+ "ent_type_recall": 1.0,
265
+ "ent_type_f1": 2 / 3,
266
+ "partial_precision": 0.25,
267
+ "partial_recall": 0.5,
268
+ "partial_f1": 1 / 3,
269
+ "exact_precision": 0.0,
270
+ "exact_recall": 0.0,
271
+ "exact_f1": 0,
272
+ },
273
+ "LOC": {
274
+ "strict_precision": 0.0,
275
+ "strict_recall": 0.0,
276
+ "strict_f1": 0,
277
+ "ent_type_precision": 0.5,
278
+ "ent_type_recall": 1.0,
279
+ "ent_type_f1": 2 / 3,
280
+ "partial_precision": 0.25,
281
+ "partial_recall": 0.5,
282
+ "partial_f1": 1 / 3,
283
+ "exact_precision": 0.0,
284
+ "exact_recall": 0.0,
285
+ "exact_f1": 0,
286
+ },
287
+ },
288
+ },
289
+ ]
290
+
291
+
292
+ def compare_results(result1, result2):
293
+ # recursively check if dictionaries are equal
294
+ if isinstance(result1, dict):
295
+ for key in result1.keys():
296
+ if not compare_results(result1[key], result2[key]):
297
+ return False
298
+ return True
299
+ elif isinstance(result1, list):
300
+ for item1, item2 in zip(result1, result2):
301
+ if not compare_results(item1, item2):
302
+ return False
303
+ return True
304
+ else:
305
+ return result1 == result2
306
+
307
+
308
+ @pytest.mark.parametrize("case", test_cases)
309
+ def test_metric(case):
310
+ if "results" not in case:
311
+ with pytest.raises(ValueError):
312
+ results = ner_eval.compute(
313
+ predictions=[case["predictions"]], references=[case["references"]]
314
+ )
315
+ else:
316
+ results = ner_eval.compute(
317
+ predictions=[case["predictions"]], references=[case["references"]]
318
+ )
319
+ assert compare_results(results, case["results"])