|
from allennlp.training.metrics import Metric |
|
from overrides import overrides |
|
|
|
from .base_f import BaseF |
|
from ..utils import Span |
|
|
|
|
|
@Metric.register('exact_match')
class ExactMatch(BaseF):
    """Span-matching metric registered with ``Metric`` as ``'exact_match'``.

    The ``check_type`` flag is stored on the instance and forwarded to
    ``Span.match`` on every call. It also selects the string handed to the
    ``BaseF`` constructor: ``'em'`` when set, ``'sm'`` otherwise.
    NOTE(review): presumably that string is a prefix/name for the reported
    scores -- confirm against ``BaseF``.
    """

    def __init__(self, check_type: bool):
        """Store ``check_type`` and initialise the base class.

        Args:
            check_type: Passed as the second argument to ``Span.match``;
                also decides whether ``'em'`` or ``'sm'`` is given to
                ``BaseF.__init__``.
        """
        # The attribute is assigned before the base initialiser runs, as in
        # the original ordering.
        self.check_type = check_type
        base_arg = 'em' if check_type else 'sm'
        super().__init__(base_arg)

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """Add the counts for one prediction/gold pair to the running totals.

        Args:
            prediction: The predicted span.
            gold: The reference span.

        Raises:
            AssertionError: If any of the derived counts is negative.
        """
        # Each quantity is reduced by one.
        # NOTE(review): presumably this discounts a root/virtual node that is
        # always present and always matched -- confirm against ``Span``.
        matched = prediction.match(gold, self.check_type) - 1
        spurious = prediction.n_nodes - matched - 1
        missed = gold.n_nodes - matched - 1

        # Sanity check: none of the three counts may be negative.
        assert all(count >= 0 for count in (matched, spurious, missed))

        # Accumulate into the tp / fp / fn counters
        # (presumably initialised by ``BaseF`` -- verify).
        self.tp += matched
        self.fp += spurious
        self.fn += missed
|
|