# sociolome/scripts/archive/eval_tie.py
# Author: Gosse Minnema
# Commit: Initial commit (05922fb)
import os
import sys
import json
from pprint import pprint
from collections import defaultdict
from sftp.metrics.exact_match import ExactMatch
def _read_jsonl(path):
    """Return the JSON objects stored one per line in the file at ``path``."""
    # The context manager closes the handle as soon as parsing is done.
    # Whitespace-only lines (e.g. a trailing blank line) are skipped because
    # json.loads would raise on them.
    with open(path) as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _spans_from_predictions(pred_sent):
    """Flatten the predicted frames of one sentence into a list of span dicts.

    Each frame contributes a dict with its ``start_idx``, ``end_idx`` and
    ``label`` plus ``parent`` 0; each entry of its ``children`` contributes
    the same three keys plus ``parent`` set to the 1-based position of the
    frame in ``pred_sent``. All frames come first, then all children.
    """
    span_keys = ["start_idx", "end_idx", "label"]
    frames, fes = [], []
    for fr_idx, fr in enumerate(pred_sent):
        frame = {key: fr[key] for key in span_keys}
        frame['parent'] = 0
        frames.append(frame)
        for fe in fr['children']:
            child = {key: fe[key] for key in span_keys}
            child['parent'] = fr_idx + 1
            fes.append(child)
    return frames + fes


def _spans_from_gold(gold_frames):
    """Flatten the gold ``frame`` list of one sentence into span dicts.

    A frame's span runs from the first to the last element of its ``target``
    and is labelled with its ``name``; every ``(start, end, name)`` triple in
    ``fe`` becomes a span whose ``parent`` is the 1-based position of the
    frame. The layout mirrors ``_spans_from_predictions`` so the two lists
    can be compared directly.
    """
    frames, fes = [], []
    for fr_idx, fr in enumerate(gold_frames):
        frames.append({
            'start_idx': fr['target'][0], 'end_idx': fr['target'][-1],
            "label": fr['name'], 'parent': 0,
        })
        for start_idx, end_idx, fe_name in fr['fe']:
            fes.append({
                "start_idx": start_idx, "end_idx": end_idx,
                "label": fe_name, "parent": fr_idx + 1,
            })
    return frames + fes


def evaluate():
    """Score a prediction file against a gold file and print both metrics.

    Expects exactly two command-line arguments: the gold file and the
    prediction file, both with one JSON object per line. Records are matched
    on ``meta['sentence ID']``; a gold sentence without predictions is scored
    against an empty prediction list. Results of the two ``ExactMatch``
    instances are pretty-printed under the headings ``EM`` and ``SM``.

    Raises:
        ValueError: if the number of command-line arguments is not two.
    """
    # NOTE(review): the meaning of the boolean flag is defined by ExactMatch
    # itself and is not visible here; judging by the printed headings it
    # presumably selects between two matching modes - confirm.
    em = ExactMatch(True)
    sm = ExactMatch(False)
    gold_file, pred_file = sys.argv[1:]

    # If a sentence ID occurs more than once in the gold file, the last
    # record wins.
    test_sentences = {
        sent['meta']['sentence ID']: sent for sent in _read_jsonl(gold_file)
    }

    # The prediction file may contain several records per sentence; they are
    # grouped by sentence ID in file order.
    pred_sentences = defaultdict(list)
    for one_pred in _read_jsonl(pred_file):
        pred_sentences[one_pred['meta']['sentence ID']].append(one_pred)

    for sent_id, gold_sent in test_sentences.items():
        # .get() rather than indexing so the defaultdict does not grow an
        # entry for every sentence that has no predictions.
        pred_to_eval = _spans_from_predictions(pred_sentences.get(sent_id, []))
        gold_to_eval = _spans_from_gold(gold_sent['frame'])
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    print('EM')
    pprint(em.get_metric(True))
    print('SM')
    pprint(sm.get_metric(True))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    evaluate()