import json
from argparse import ArgumentParser
from collections import defaultdict

import numpy as np
from tqdm import tqdm
from nltk.corpus import framenet as fn

from sftp import SpanPredictor


def _load_jsonl(path):
    """Read a JSON-lines file and return the decoded objects as a list.

    Whitespace-only lines (e.g. a trailing blank line) are skipped instead of
    being handed to ``json.loads``, and the file is closed on exit.
    """
    with open(path, encoding='utf-8') as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _accuracy_pct(n_positive, n_total):
    """Return ``n_positive / n_total`` as a percentage, or 0.0 if nothing was scored."""
    return n_positive / n_total * 100 if n_total else 0.0


def run(model_path, data_path, device, use_ontology=False):
    """Force-decode every annotated span and print frame-label accuracy.

    Each line of ``data_path`` must be a JSON object with a ``'tokens'`` entry
    and an ``'annotations'`` list; every annotation provides ``'span'`` and
    ``'label'``, plus ``'lu'`` when ``use_ontology`` is set.

    Args:
        model_path: Location passed to ``SpanPredictor.from_path``.
        data_path: Path of the JSON-lines evaluation file.
        device: Forwarded as ``cuda_device`` to ``SpanPredictor.from_path``.
            NOTE(review): -1 presumably selects the CPU - confirm against sftp.
        use_ontology: When true, candidate frames for an annotation are limited
            to the frames FrameNet lists for its lexical unit; otherwise every
            FrameNet frame is a candidate.

    The running accuracy is shown on the progress bar and the final value is
    printed as ``acc=<percent>``. An annotation with no candidate frames still
    counts towards the total, i.e. it is scored as wrong.
    """
    data = _load_jsonl(data_path)

    # Only build the (slow) FrameNet index that the selected mode actually uses.
    lu2frame = defaultdict(list)
    all_frames = []
    if use_ontology:
        for lu in fn.lus():
            lu2frame[lu.name].append(lu.frame.name)

    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    frame2idx = predictor._model.vocab.get_token_to_index_vocabulary('span_label')
    if not use_ontology:
        all_frames = [fr.name for fr in fn.frames()]

    n_positive = n_total = 0
    with tqdm(total=len(data)) as bar:
        for sent in data:
            bar.update()
            for point in sent['annotations']:
                # Score the single span given by the first and last entries of
                # point['span']; the result is indexed below by label id.
                model_output = predictor.force_decode(
                    sent['tokens'],
                    child_spans=[(point['span'][0], point['span'][-1])],
                ).distribution[0]

                if use_ontology:
                    # .get() avoids inserting an empty list for unknown LUs.
                    candidate_frames = lu2frame.get(point['lu'], [])
                else:
                    candidate_frames = all_frames

                # Frames missing from the model vocabulary get -1.0.
                # NOTE(review): assumes real scores are never below -1.0 - confirm.
                candidate_prob = [
                    model_output[frame2idx[fr]] if fr in frame2idx else -1.0
                    for fr in candidate_frames
                ]
                if candidate_prob:
                    pred_frame = candidate_frames[int(np.argmax(candidate_prob))]
                    if pred_frame == point['label']:
                        n_positive += 1
                n_total += 1
            # Guarded: the first sentences may carry no annotations at all.
            bar.set_description(f'acc={_accuracy_pct(n_positive, n_total):.3f}')
    print(f'acc={_accuracy_pct(n_positive, n_total):.3f}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('model', metavar="MODEL")
    parser.add_argument('data', metavar="DATA")
    parser.add_argument('-d', default=-1, type=int, help='Device')
    parser.add_argument('-o', action='store_true', help='Flag to use ontology.')
    args = parser.parse_args()
    run(args.model, args.data, args.d, args.o)