Spaces:
Build error
Build error
File size: 5,571 Bytes
6680682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from typing import *
from allennlp.training.metrics import Metric
from overrides import overrides
import numpy as np
import logging
from .base_f import BaseF
from ..utils import Span, max_match
# Module logger; shared by every SRLMetric instance (used for the deprecation warning).
logger = logging.getLogger(name='srl_metric')
@Metric.register('srl')
class SRLMetric(Metric):
    """Accumulate trigger and argument P/R/F counts for predicted vs. gold ``Span`` trees.

    Four ``BaseF`` counters are kept:

    * ``tri_c`` / ``tri_i`` -- events compared with / without labels
      (``with_label_event`` / ``without_label_event``);
    * ``arg_c`` / ``arg_i`` -- argument tuples compared with / without the argument
      label (``tuple_eval``).
    """

    def __init__(self, check_type: Optional[bool] = None):
        """Create zeroed counters.

        :param check_type: deprecated and ignored; any non-``None`` value only logs a
            warning.
        """
        self.tri_i = BaseF('tri-i')
        self.tri_c = BaseF('tri-c')
        self.arg_i = BaseF('arg-i')
        self.arg_c = BaseF('arg-c')
        if check_type is not None:
            logger.warning('Check type argument is deprecated.')

    def _counters(self) -> List[BaseF]:
        """Return the four counters in a fixed order (tri-i, tri-c, arg-i, arg-c)."""
        return [self.tri_i, self.tri_c, self.arg_i, self.arg_c]

    @staticmethod
    def _accumulate(counter: BaseF, tp: int, fp: int, fn: int) -> None:
        """Check that one example's counts are non-negative, then add them to ``counter``."""
        assert tp >= 0 and fp >= 0 and fn >= 0
        counter.tp += tp
        counter.fp += fp
        counter.fn += fn

    def reset(self) -> None:
        """Reset all four counters."""
        for metric in self._counters():
            metric.reset()

    def get_metric(self, reset: bool) -> Dict[str, Any]:
        """Merge the dicts returned by each counter's ``get_metric``.

        :param reset: forwarded unchanged to every counter's ``get_metric``.
        """
        ret: Dict[str, Any] = dict()
        for metric in self._counters():
            ret.update(metric.get_metric(reset))
        return ret

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """Accumulate counts for one prediction/gold pair.

        Argument scores come from ``tuple_eval``. The older ``with_label_arg`` /
        ``without_label_arg`` routines are kept on the class but are not called here.
        """
        self.with_label_event(prediction, gold)
        self.without_label_event(prediction, gold)
        self.tuple_eval(prediction, gold)

    def tuple_eval(self, prediction: Span, gold: Span):
        """Score arguments as tuples and add the counts to ``arg_c`` and ``arg_i``.

        Each argument becomes ``(event.label, arg.boundary, arg.label)`` (for ``arg_c``)
        or ``(event.label, arg.boundary)`` (for ``arg_i``). Predicted and gold tuples
        are paired with ``max_match`` over a 0/1 equality matrix -- presumably a maximum
        bipartite matching, so repeated tuples cannot be over-credited (confirm in
        ``max_match``).
        """
        def extract_tuples(vr: Span, parent_boundary: bool):
            # Walks event -> argument children; parent_boundary adds the event's own
            # boundary to every tuple (currently always called with False).
            labeled, unlabeled = list(), list()
            for event in vr:
                for arg in event:
                    if parent_boundary:
                        labeled.append((event.boundary, event.label, arg.boundary, arg.label))
                        unlabeled.append((event.boundary, event.label, arg.boundary))
                    else:
                        labeled.append((event.label, arg.boundary, arg.label))
                        unlabeled.append((event.label, arg.boundary))
            return labeled, unlabeled

        def equal_matrix(l1, l2):
            # ``np.int`` was only an alias of the builtin ``int`` and was removed in
            # NumPy 1.24; ``int`` gives the identical dtype.
            return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=int)

        pred_label, pred_unlabel = extract_tuples(prediction, False)
        gold_label, gold_unlabel = extract_tuples(gold, False)
        if len(pred_label) == 0 or len(gold_label) == 0:
            arg_c_tp = arg_i_tp = 0
        else:
            arg_c_tp = max_match(equal_matrix(pred_label, gold_label))
            arg_i_tp = max_match(equal_matrix(pred_unlabel, gold_unlabel))

        # Argument count = all nodes minus the events minus one more node (presumably
        # the root) -- assumes a root/event/argument tree; TODO confirm in Span.n_nodes.
        n_pred_args = prediction.n_nodes - len(prediction) - 1
        n_gold_args = gold.n_nodes - len(gold) - 1
        self._accumulate(self.arg_i, arg_i_tp, n_pred_args - arg_i_tp, n_gold_args - arg_i_tp)
        self._accumulate(self.arg_c, arg_c_tp, n_pred_args - arg_c_tp, n_gold_args - arg_c_tp)

    def with_label_event(self, prediction: Span, gold: Span):
        """Add labeled event counts to ``tri_c``.

        ``match(gold, True, 2)`` presumably counts label-aware matching nodes down to
        depth 2 including the root, hence the ``- 1`` -- TODO confirm in ``Span.match``.
        """
        tp = prediction.match(gold, True, 2) - 1
        self._accumulate(self.tri_c, tp, len(prediction) - tp, len(gold) - tp)

    def with_label_arg(self, prediction: Span, gold: Span):
        """Legacy labeled argument scoring into ``arg_c`` (not called by ``__call__``).

        Labeled root+event matches are subtracted from the
        ``ignore_parent_boundary=True`` match count to leave only argument matches.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
        role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp
        role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp
        self._accumulate(self.arg_c, role_tp, role_fp, role_fn)

    def without_label_event(self, prediction: Span, gold: Span):
        """Add unlabeled event counts to ``tri_i`` (``match`` with labels disabled)."""
        tp = prediction.match(gold, False, 2) - 1
        self._accumulate(self.tri_i, tp, len(prediction) - tp, len(gold) - tp)

    def without_label_arg(self, prediction: Span, gold: Span):
        """Legacy unlabeled argument scoring into ``arg_i`` (not called by ``__call__``).

        The work happens on clones, so the callers' spans are not modified. The scoring
        runs in two stages:

        1. Greedy stage. Each predicted event is paired with the first still-unpaired
           gold event for which ``p.match(g, True, 1) == 1``. Its unlabeled argument
           matches are then credited, and the pair is removed from both clones.
        2. Matching stage. The remaining events with equal labels are scored with
           ``p.match(g, False, -1, True)``. They are then paired via ``max_match``.
        """
        n_gold_arg = gold.n_nodes - len(gold) - 1
        n_pred_arg = prediction.n_nodes - len(prediction) - 1
        prediction, gold = prediction.clone(), gold.clone()

        arg_i_tp = 0
        matched_pairs: List[Tuple[Span, Span]] = list()
        # A gold event may be paired at most once; without this, two predicted events
        # could both claim it, double-counting tp and removing the same child twice.
        paired_gold: Set[int] = set()
        for p in prediction:
            for g in gold:
                if id(g) in paired_gold:
                    continue
                if p.match(g, True, 1) == 1:
                    arg_i_tp += p.match(g, False) - 1
                    matched_pairs.append((p, g))
                    paired_gold.add(id(g))
                    break
        for p, g in matched_pairs:
            prediction.remove_child(p)
            gold.remove_child(g)

        sub_matches = np.zeros([len(prediction), len(gold)], dtype=int)
        for p_idx, p in enumerate(prediction):
            for g_idx, g in enumerate(gold):
                if p.label == g.label:
                    sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
        arg_i_tp += max_match(sub_matches)

        self._accumulate(self.arg_i, arg_i_tp, n_pred_arg - arg_i_tp, n_gold_arg - arg_i_tp)
|