File size: 5,571 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from typing import *

from allennlp.training.metrics import Metric
from overrides import overrides
import numpy as np
import logging

from .base_f import BaseF
from ..utils import Span, max_match

# Named module logger; SRLMetric.__init__ uses it to warn about the
# deprecated ``check_type`` argument.
logger = logging.getLogger(name='srl_metric')


@Metric.register('srl')
class SRLMetric(Metric):
    """
    Accumulate trigger and argument P/R/F counters over ``Span`` trees.

    Four ``BaseF`` counters are kept:

    * ``tri-i`` / ``tri-c``: events matched between prediction and gold,
      ignoring / respecting labels (``without_label_event`` / ``with_label_event``).
    * ``arg-i`` / ``arg-c``: argument tuples matched one-to-one, ignoring /
      respecting argument labels (``tuple_eval``).

    A ``Span`` is iterated as root -> events -> arguments.
    NOTE(review): the FP/FN arithmetic ``n_nodes - len(span) - 1`` equals the
    argument count only when the tree has exactly those three levels -- confirm
    against ``Span``.
    """

    def __init__(self, check_type: Optional[bool] = None):
        """
        :param check_type: deprecated and ignored; any non-``None`` value only
            logs a warning.
        """
        self.tri_i = BaseF('tri-i')
        self.tri_c = BaseF('tri-c')
        self.arg_i = BaseF('arg-i')
        self.arg_c = BaseF('arg-c')
        if check_type is not None:
            logger.warning('Check type argument is deprecated.')

    def _counters(self) -> tuple:
        """Return the four counters in a fixed order: tri-i, tri-c, arg-i, arg-c."""
        return self.tri_i, self.tri_c, self.arg_i, self.arg_c

    def reset(self) -> None:
        """Reset all four counters."""
        for metric in self._counters():
            metric.reset()

    def get_metric(self, reset: bool = False) -> Dict[str, Any]:
        """
        Return the merged metric dicts of all four counters.

        :param reset: forwarded to each counter's ``get_metric``.
        :return: one dict holding every key produced by the four counters.
        """
        ret: Dict[str, Any] = dict()
        for metric in self._counters():
            ret.update(metric.get_metric(reset))
        return ret

    @overrides
    def __call__(self, prediction: Span, gold: Span):
        """Accumulate trigger and argument counts for one (prediction, gold) pair."""
        self.with_label_event(prediction, gold)
        self.without_label_event(prediction, gold)
        # tuple_eval updates both arg-i and arg-c; with_label_arg and
        # without_label_arg are alternative argument scorers that are not used.
        self.tuple_eval(prediction, gold)

    def tuple_eval(self, prediction: Span, gold: Span):
        """
        Update ``arg-i`` and ``arg-c`` by one-to-one matching of argument tuples.

        Every argument becomes ``(event.label, arg.boundary, arg.label)`` for
        ``arg-c`` and ``(event.label, arg.boundary)`` for ``arg-i``. The event
        boundary is left out (``parent_boundary=False``), so an argument can be
        credited when its event label matches even if the trigger span differs.
        True positives come from ``max_match`` over the 0/1 pairwise-equality
        matrix (presumably the size of a maximum bipartite matching -- confirm
        in ``..utils``).
        """
        def extract_tuples(vr: Span, parent_boundary: bool):
            # One labeled and one unlabeled tuple per argument node.
            labeled, unlabeled = list(), list()
            for event in vr:
                for arg in event:
                    if parent_boundary:
                        labeled.append((event.boundary, event.label, arg.boundary, arg.label))
                        unlabeled.append((event.boundary, event.label, arg.boundary))
                    else:
                        labeled.append((event.label, arg.boundary, arg.label))
                        unlabeled.append((event.label, arg.boundary))
            return labeled, unlabeled

        def equal_matrix(l1, l2):
            # Entry [i, j] is 1 iff l1[i] == l2[j]. ``dtype=int`` replaces
            # ``np.int`` (an alias of builtin int), which NumPy deprecated in
            # 1.20 and removed in 1.24.
            return np.array([[e1 == e2 for e2 in l2] for e1 in l1], dtype=int)

        pred_label, pred_unlabel = extract_tuples(prediction, False)
        gold_label, gold_unlabel = extract_tuples(gold, False)

        if len(pred_label) == 0 or len(gold_label) == 0:
            # Nothing can match when either side has no arguments.
            arg_c_tp = arg_i_tp = 0
        else:
            label_bipartite = equal_matrix(pred_label, gold_label)
            unlabel_bipartite = equal_matrix(pred_unlabel, gold_unlabel)
            arg_c_tp, arg_i_tp = max_match(label_bipartite), max_match(unlabel_bipartite)

        # All nodes minus the events minus the root: the argument count for a
        # root -> events -> arguments tree.
        n_pred_arg = prediction.n_nodes - len(prediction) - 1
        n_gold_arg = gold.n_nodes - len(gold) - 1

        arg_c_fp, arg_c_fn = n_pred_arg - arg_c_tp, n_gold_arg - arg_c_tp
        arg_i_fp, arg_i_fn = n_pred_arg - arg_i_tp, n_gold_arg - arg_i_tp

        assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0
        self.arg_i.tp += arg_i_tp
        self.arg_i.fp += arg_i_fp
        self.arg_i.fn += arg_i_fn

        assert arg_c_tp >= 0 and arg_c_fn >= 0 and arg_c_fp >= 0
        self.arg_c.tp += arg_c_tp
        self.arg_c.fp += arg_c_fp
        self.arg_c.fn += arg_c_fn

    def with_label_event(self, prediction: Span, gold: Span):
        """
        Update ``tri-c`` with events matched with labels at depth 2.

        The ``- 1`` presumably discounts the root node, which ``Span.match``
        appears to include in its count -- confirm in ``Span.match``.
        FP = predicted events - TP; FN = gold events - TP.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        trigger_fp = len(prediction) - trigger_tp
        trigger_fn = len(gold) - trigger_tp
        assert trigger_fp >= 0 and trigger_fn >= 0 and trigger_tp >= 0
        self.tri_c.tp += trigger_tp
        self.tri_c.fp += trigger_fp
        self.tri_c.fn += trigger_fn

    def with_label_arg(self, prediction: Span, gold: Span):
        """
        Alternative ``arg-c`` scorer; not called by ``__call__``.

        Labeled role TP = labeled match count ignoring parent boundaries, minus
        the root and minus the labeled event TP.
        """
        trigger_tp = prediction.match(gold, True, 2) - 1
        role_tp = prediction.match(gold, True, ignore_parent_boundary=True) - 1 - trigger_tp
        role_fp = (prediction.n_nodes - 1 - len(prediction)) - role_tp
        role_fn = (gold.n_nodes - 1 - len(gold)) - role_tp
        assert role_fp >= 0 and role_fn >= 0 and role_tp >= 0
        self.arg_c.tp += role_tp
        self.arg_c.fp += role_fp
        self.arg_c.fn += role_fn

    def without_label_event(self, prediction: Span, gold: Span):
        """
        Update ``tri-i`` with events matched while ignoring labels at depth 2.

        The ``- 1`` presumably discounts the root node (see ``with_label_event``).
        """
        tri_i_tp = prediction.match(gold, False, 2) - 1
        tri_i_fp = len(prediction) - tri_i_tp
        tri_i_fn = len(gold) - tri_i_tp
        assert tri_i_tp >= 0 and tri_i_fp >= 0 and tri_i_fn >= 0
        self.tri_i.tp += tri_i_tp
        self.tri_i.fp += tri_i_fp
        self.tri_i.fn += tri_i_fn

    def without_label_arg(self, prediction: Span, gold: Span):
        """
        Alternative ``arg-i`` scorer; not called by ``__call__``.

        1. Pair each predicted event with the first unclaimed gold event for
           which ``p.match(g, True, 1) == 1``, and credit their unlabeled
           argument matches (``p.match(g, False) - 1``).
        2. Remove those pairs from cloned trees. Among the remaining events,
           score same-label pairs with ``p.match(g, False, -1, True)`` (argument
           meaning not visible here), then add ``max_match`` of that matrix.
        """
        arg_i_tp = 0
        matched_pairs: List[Tuple[Span, Span]] = list()
        n_gold_arg, n_pred_arg = gold.n_nodes - len(gold) - 1, prediction.n_nodes - len(prediction) - 1
        prediction, gold = prediction.clone(), gold.clone()
        # Each gold event may be claimed once; otherwise duplicate predicted
        # events would double-count its arguments and remove it twice.
        claimed_gold = set()
        for p in prediction:
            for g in gold:
                if id(g) in claimed_gold:
                    continue
                if p.match(g, True, 1) == 1:
                    arg_i_tp += (p.match(g, False) - 1)
                    matched_pairs.append((p, g))
                    claimed_gold.add(id(g))
                    break
        for p, g in matched_pairs:
            prediction.remove_child(p)
            gold.remove_child(g)

        # ``dtype=int`` replaces the removed NumPy alias ``np.int``.
        sub_matches = np.zeros([len(prediction), len(gold)], dtype=int)
        for p_idx, p in enumerate(prediction):
            for g_idx, g in enumerate(gold):
                if p.label == g.label:
                    sub_matches[p_idx, g_idx] = p.match(g, False, -1, True)
        arg_i_tp += max_match(sub_matches)

        arg_i_fp = n_pred_arg - arg_i_tp
        arg_i_fn = n_gold_arg - arg_i_tp
        assert arg_i_tp >= 0 and arg_i_fn >= 0 and arg_i_fp >= 0

        self.arg_i.tp += arg_i_tp
        self.arg_i.fp += arg_i_fp
        self.arg_i.fn += arg_i_fn