File size: 765 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from allennlp.training.metrics import Metric
from overrides import overrides

from .base_f import BaseF
from ..utils import Span


@Metric.register('exact_match')
class ExactMatch(BaseF):
    def __init__(self, check_type: bool):
        self.check_type = check_type
        if check_type:
            super(ExactMatch, self).__init__('em')
        else:
            super(ExactMatch, self).__init__('sm')

    @overrides
    def __call__(
            self,
            prediction: Span,
            gold: Span,
    ):
        tp = prediction.match(gold, self.check_type) - 1
        fp = prediction.n_nodes - tp - 1
        fn = gold.n_nodes - tp - 1
        assert tp >= 0 and fp >= 0 and fn >= 0
        self.tp += tp
        self.fp += fp
        self.fn += fn