|
import os |
|
import sys |
|
import json |
|
from pprint import pprint |
|
from collections import defaultdict |
|
|
|
from sftp.metrics.exact_match import ExactMatch |
|
|
|
|
|
def _read_jsonl(path):
    """Return the decoded JSON objects of a JSON-lines file, one per line.

    Whitespace-only lines are skipped so that a stray blank line (for
    example at the end of the file) does not abort the whole evaluation.
    """
    # JSON interchange text is UTF-8, so do not depend on the locale default.
    # The context manager guarantees the handle is closed after reading.
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _flatten_predictions(pred_sent):
    """Flatten the predicted frames of one sentence into a list of span dicts.

    Each entry of ``pred_sent`` must provide ``start_idx``, ``end_idx``,
    ``label`` and a ``children`` list whose items provide the same three keys.
    Frames get ``parent`` 0; every child gets ``parent`` equal to the
    1-based position of its frame. All frames come before all children.
    """
    span_keys = ("start_idx", "end_idx", "label")
    frames, elements = [], []
    for frame_idx, frame in enumerate(pred_sent):
        frame_span = {key: frame[key] for key in span_keys}
        frame_span["parent"] = 0
        frames.append(frame_span)
        for child in frame["children"]:
            child_span = {key: child[key] for key in span_keys}
            child_span["parent"] = frame_idx + 1
            elements.append(child_span)
    return frames + elements


def _flatten_gold(gold_sent):
    """Flatten the gold frames of one sentence into a list of span dicts.

    ``gold_sent['frame']`` is iterated; for each frame the first and last
    entries of ``target`` become the span boundaries and ``name`` the label.
    Each ``fe`` item is unpacked as ``(start_idx, end_idx, fe_name)``.
    The ``parent`` numbering matches ``_flatten_predictions``.
    """
    frames, elements = [], []
    for frame_idx, frame in enumerate(gold_sent["frame"]):
        target = frame["target"]
        frames.append({
            "start_idx": target[0],
            "end_idx": target[-1],
            "label": frame["name"],
            "parent": 0,
        })
        for start_idx, end_idx, fe_name in frame["fe"]:
            elements.append({
                "start_idx": start_idx,
                "end_idx": end_idx,
                "label": fe_name,
                "parent": frame_idx + 1,
            })
    return frames + elements


def evaluate():
    """Score a prediction file against a gold file and print the metrics.

    The two paths are taken from the command line as
    ``<gold_file> <pred_file>``; both files hold one JSON object per line.

    * Gold lines provide ``meta['sentence ID']`` and a ``frame`` list.
    * Prediction lines provide ``meta['sentence ID']`` plus the span keys
      read by ``_flatten_predictions``; several prediction lines may share
      one sentence ID.

    Only sentence IDs present in the gold file are scored. A gold sentence
    without predictions is scored against an empty prediction list.
    Results of both ``ExactMatch`` instances are pretty-printed under the
    headings ``EM`` and ``SM``.

    Raises:
        ValueError: if the command line does not carry exactly two arguments.
    """
    # NOTE(review): the meaning of the boolean constructor argument (and of
    # the ``True`` passed to ``get_metric``) is defined by ExactMatch, which
    # is not visible from this file -- confirm against that class.
    em = ExactMatch(True)
    sm = ExactMatch(False)
    gold_file, pred_file = sys.argv[1:]

    # Index gold sentences by ID; on duplicate IDs the later line wins.
    test_sentences = {}
    for gold_sent in _read_jsonl(gold_file):
        test_sentences[gold_sent["meta"]["sentence ID"]] = gold_sent

    # Group predictions by the sentence they belong to.
    pred_sentences = defaultdict(list)
    for one_pred in _read_jsonl(pred_file):
        pred_sentences[one_pred["meta"]["sentence ID"]].append(one_pred)

    for sent_id, gold_sent in test_sentences.items():
        # .get() avoids inserting empty entries into the defaultdict.
        pred_to_eval = _flatten_predictions(pred_sentences.get(sent_id, []))
        gold_to_eval = _flatten_gold(gold_sent)
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    print('EM')
    pprint(em.get_metric(True))
    print('SM')
    pprint(sm.get_metric(True))
|
|
|
|
|
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    evaluate()
|
|