|
from typing import * |
|
|
|
from allennlp.training.metrics import Metric |
|
from overrides import overrides |
|
import numpy as np |
|
import logging |
|
|
|
from .base_f import BaseF |
|
from ..utils import Span, max_match |
|
|
|
# Module-level logger. It is named 'srl_metric' explicitly (not after the
# module path); keep the name stable if anything configures it by name.
logger = logging.getLogger(name='srl_metric')
|
|
|
|
|
@Metric.register('srl')
class SRLMetric(Metric):
    """Accumulate trigger and argument P/R/F counters for event/SRL structures.

    Four ``BaseF`` counters are kept and reported together by ``get_metric``:

    * ``tri-i`` -- events matched without comparing labels (``without_label_event``).
    * ``tri-c`` -- events matched including labels (``with_label_event``).
    * ``arg-i`` / ``arg-c`` -- arguments matched as tuples without / with the
      argument label (``tuple_eval``).

    NOTE(review): each ``Span`` is assumed to be a tree of
    root -> events -> arguments; the argument count used throughout is
    ``n_nodes - len(span) - 1`` (all nodes minus events minus, presumably, the
    root). Confirm against ``Span``.
    """

    def __init__(self, check_type: Optional[bool] = None):
        """Create the four counters.

        :param check_type: deprecated and ignored; passing any value other than
            ``None`` only logs a warning.
        """
        self.tri_i = BaseF('tri-i')
        self.tri_c = BaseF('tri-c')
        self.arg_i = BaseF('arg-i')
        self.arg_c = BaseF('arg-c')
        if check_type is not None:
            logger.warning('Check type argument is deprecated.')

    def _all_metrics(self) -> Tuple[BaseF, ...]:
        """Return the four counters in reporting order."""
        return self.tri_i, self.tri_c, self.arg_i, self.arg_c

    @staticmethod
    def _n_args(span: Span) -> int:
        """Number of argument nodes in ``span``.

        Computed as all nodes minus the events (``len(span)``) minus one,
        presumably the root node -- confirm against ``Span``.
        """
        return span.n_nodes - len(span) - 1

    @staticmethod
    def _accumulate(metric: BaseF, tp: int, fp: int, fn: int) -> None:
        """Check the counts are non-negative, then add them to ``metric``."""
        assert tp >= 0 and fp >= 0 and fn >= 0
        metric.tp += tp
        metric.fp += fp
        metric.fn += fn

    def reset(self) -> None:
        """Reset every counter."""
        for metric in self._all_metrics():
            metric.reset()

    def get_metric(self, reset: bool) -> Dict[str, Any]:
        """Merge the result dicts of all four counters.

        :param reset: forwarded to each ``BaseF.get_metric``.
        :return: a single dict with every counter's entries.
        """
        ret: Dict[str, Any] = dict()
        for metric in self._all_metrics():
            ret.update(metric.get_metric(reset))
        return ret

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """Update all counters with one predicted/gold pair.

        Argument counters are updated by ``tuple_eval`` only;
        ``with_label_arg`` / ``without_label_arg`` are NOT called here, and
        calling them in addition would double-count ``arg-c`` / ``arg-i``.
        """
        self.with_label_event(prediction, gold)
        self.without_label_event(prediction, gold)
        self.tuple_eval(prediction, gold)

    def tuple_eval(self, prediction: Span, gold: Span):
        """Update ``arg-c`` and ``arg-i`` by matching argument tuples.

        Each argument becomes a tuple. For ``arg-c`` it is
        ``(event.label, arg.boundary, arg.label)``; for ``arg-i`` it is
        ``(event.label, arg.boundary)``. The event boundary is not part of the
        tuple, so only the parent event's label must agree. True positives are
        the result of ``max_match`` on the pred x gold equality matrix
        (presumably a maximum bipartite matching, so duplicate tuples are not
        over-counted -- confirm in ``utils``).
        """
        def extract_tuples(vr: Span, parent_boundary: bool):
            # Flatten every (event, argument) pair into comparable tuples.
            labeled, unlabeled = list(), list()
            for event in vr:
                for arg in event:
                    if parent_boundary:
                        labeled.append((event.boundary, event.label, arg.boundary, arg.label))
                        unlabeled.append((event.boundary, event.label, arg.boundary))
                    else:
                        labeled.append((event.label, arg.boundary, arg.label))
                        unlabeled.append((event.label, arg.boundary))
            return labeled, unlabeled

        def equal_matrix(l1, l2):
            # 0/1 matrix: entry [i, j] is 1 iff l1[i] == l2[j].
            # The builtin int replaces np.int, which was only an alias for it
            # and was removed in NumPy 1.24.
            return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=int)

        pred_label, pred_unlabel = extract_tuples(prediction, False)
        gold_label, gold_unlabel = extract_tuples(gold, False)

        # With an empty side, equal_matrix would not even be 2-D; nothing can match.
        if len(pred_label) == 0 or len(gold_label) == 0:
            arg_c_tp = arg_i_tp = 0
        else:
            label_bipartite = equal_matrix(pred_label, gold_label)
            unlabel_bipartite = equal_matrix(pred_unlabel, gold_unlabel)
            arg_c_tp, arg_i_tp = max_match(label_bipartite), max_match(unlabel_bipartite)

        n_pred_arg, n_gold_arg = self._n_args(prediction), self._n_args(gold)

        self._accumulate(self.arg_i, arg_i_tp, n_pred_arg - arg_i_tp, n_gold_arg - arg_i_tp)
        self._accumulate(self.arg_c, arg_c_tp, n_pred_arg - arg_c_tp, n_gold_arg - arg_c_tp)

    def with_label_event(self, prediction: Span, gold: Span):
        """Update ``tri-c`` with events matched including their labels."""
        # ``- 1`` presumably discounts the matched root node -- confirm in Span.match.
        trigger_tp = prediction.match(gold, True, 2) - 1
        self._accumulate(self.tri_c, trigger_tp,
                         len(prediction) - trigger_tp, len(gold) - trigger_tp)

    def with_label_arg(self, prediction: Span, gold: Span):
        """Alternative labeled-argument scorer for ``arg-c`` (unused by ``__call__``).

        Labeled matches ignoring the parent boundary, minus the labeled trigger
        matches, give the role true positives.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
        self._accumulate(self.arg_c, role_tp,
                         self._n_args(prediction) - role_tp, self._n_args(gold) - role_tp)

    def without_label_event(self, prediction: Span, gold: Span):
        """Update ``tri-i`` with events matched without comparing labels."""
        tri_i_tp = prediction.match(gold, False, 2) - 1
        self._accumulate(self.tri_i, tri_i_tp,
                         len(prediction) - tri_i_tp, len(gold) - tri_i_tp)

    def without_label_arg(self, prediction: Span, gold: Span):
        """Alternative unlabeled-argument scorer for ``arg-i`` (unused by ``__call__``).

        First, each predicted event is greedily paired with the first gold
        event for which ``p.match(g, True, 1) == 1``; those pairs contribute
        their unlabeled argument matches and are removed. The remaining events
        are then paired via ``max_match``, restricted to equal event labels.
        Works on clones, so the caller's spans are not modified.
        """
        arg_i_tp = 0
        matched_pairs: List[Tuple[Span, Span]] = list()
        # Count arguments before cloning/removal mutates the working copies.
        n_gold_arg, n_pred_arg = self._n_args(gold), self._n_args(prediction)
        prediction, gold = prediction.clone(), gold.clone()

        for p in prediction:
            for g in gold:
                if p.match(g, True, 1) == 1:
                    arg_i_tp += (p.match(g, False) - 1)
                    matched_pairs.append((p, g))
                    break
        # Removal happens after iteration so the loops above see stable children.
        for p, g in matched_pairs:
            prediction.remove_child(p)
            gold.remove_child(g)

        # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
        sub_matches = np.zeros([len(prediction), len(gold)], int)
        for p_idx, p in enumerate(prediction):
            for g_idx, g in enumerate(gold):
                if p.label == g.label:
                    sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
        arg_i_tp += max_match(sub_matches)

        self._accumulate(self.arg_i, arg_i_tp, n_pred_arg - arg_i_tp, n_gold_arg - arg_i_tp)
|
|