import os
import sys
import json
from pprint import pprint
from collections import defaultdict

from sftp.metrics.exact_match import ExactMatch

# Keys copied verbatim from each predicted frame / frame element.
_SPAN_KEYS = ("start_idx", "end_idx", "label")


def _read_jsonl(path):
    """Yield one parsed JSON object per non-blank line of the file at *path*.

    The file is closed once iteration finishes. Blank lines (e.g. a trailing
    empty line) are skipped instead of crashing ``json.loads``.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def _pred_to_eval(pred_sent):
    """Flatten a sentence's predicted frames and their FE children into span dicts.

    Each frame becomes ``{start_idx, end_idx, label, parent=0}``. Each child FE
    becomes the same shape, with ``parent`` set to the 1-based index of its
    frame. All frames are listed first, then all FEs.
    """
    frames, fes = [], []
    for fr_idx, fr in enumerate(pred_sent, start=1):
        frames.append({**{key: fr[key] for key in _SPAN_KEYS}, "parent": 0})
        for fe in fr["children"]:
            fes.append({**{key: fe[key] for key in _SPAN_KEYS}, "parent": fr_idx})
    return frames + fes


def _gold_to_eval(gold_sent):
    """Convert a gold sentence record into the same span-dict layout as predictions.

    The frame span runs from the first to the last entry of ``fr['target']``,
    and its label is ``fr['name']``. Each ``fr['fe']`` entry is unpacked as a
    ``(start_idx, end_idx, fe_name)`` triple and linked to the 1-based frame index.
    """
    frames, fes = [], []
    for fr_idx, fr in enumerate(gold_sent["frame"], start=1):
        frames.append({
            "start_idx": fr["target"][0],
            "end_idx": fr["target"][-1],
            "label": fr["name"],
            "parent": 0,
        })
        for start_idx, end_idx, fe_name in fr["fe"]:
            fes.append({
                "start_idx": start_idx,
                "end_idx": end_idx,
                "label": fe_name,
                "parent": fr_idx,
            })
    return frames + fes


def evaluate(gold_file=None, pred_file=None):
    """Score predicted frames/FEs against gold annotations and print the metrics.

    Both inputs are JSONL files. Records are matched on ``['meta']['sentence ID']``.
    A gold sentence with no prediction is scored against an empty prediction list.
    If the gold file repeats a sentence ID, the last record wins. Several
    prediction lines may share one sentence ID; each line is treated as one
    predicted frame.

    Args:
        gold_file: Path to the gold JSONL file. If it and ``pred_file`` are both
            omitted, the two paths are read from ``sys.argv[1:]``.
        pred_file: Path to the prediction JSONL file.

    Returns:
        A tuple ``(em_metric, sm_metric)`` holding the dicts that are printed.

    Raises:
        ValueError: If only one of ``gold_file`` / ``pred_file`` is given.
        SystemExit: With a usage message when reading from the command line
            and the argument count is not exactly two.
    """
    if gold_file is None and pred_file is None:
        args = sys.argv[1:]
        if len(args) != 2:
            sys.exit(f"usage: {os.path.basename(sys.argv[0])} GOLD_JSONL PRED_JSONL")
        gold_file, pred_file = args
    elif gold_file is None or pred_file is None:
        raise ValueError("gold_file and pred_file must be given together")

    # NOTE(review): the boolean flag presumably selects exact vs. soft matching
    # (printed as 'EM' / 'SM' below) — confirm against ExactMatch.
    em = ExactMatch(True)
    sm = ExactMatch(False)

    test_sentences = {
        sent["meta"]["sentence ID"]: sent for sent in _read_jsonl(gold_file)
    }

    pred_sentences = defaultdict(list)
    for one_pred in _read_jsonl(pred_file):
        pred_sentences[one_pred["meta"]["sentence ID"]].append(one_pred)

    for sent_id, gold_sent in test_sentences.items():
        # .get avoids inserting empty entries into the defaultdict for misses.
        pred_to_eval = _pred_to_eval(pred_sentences.get(sent_id, []))
        gold_to_eval = _gold_to_eval(gold_sent)
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    # get_metric is called once per metric and the result reused.
    # NOTE(review): the True argument likely resets accumulated state — confirm.
    em_metric = em.get_metric(True)
    sm_metric = sm.get_metric(True)
    print('EM')
    pprint(em_metric)
    print('SM')
    pprint(sm_metric)
    return em_metric, sm_metric


if __name__ == '__main__':
    evaluate()