"""Dump the top-k span labels predicted for each annotated span in a corpus.

Usage: <script> AIDA MODEL DST [--topk K] [--device D]

AIDA is a JSON file holding a list of records. Each record must provide
``tokens`` and an ``annotation`` dict with ``start_idx``, ``end_idx`` and
``label``. For every record the span ``(start_idx, end_idx)`` is
force-decoded with the ``SpanPredictor`` loaded from MODEL, and one JSON
line is appended to DST with the keys ``tokens``, ``frames`` (a list of
``[label, probability]`` pairs, highest probability first) and ``kairos``
(the original annotation label).
"""
import json
import os
import copy
import random
from argparse import ArgumentParser
from collections import defaultdict

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor

parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
parser.add_argument('dst', type=str)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

k = args.topk

# Read the corpus inside a context manager so the handle is closed right
# after parsing instead of lingering until interpreter shutdown.
with open(args.aida, encoding='utf-8') as corpus_fp:
    corpus = json.load(corpus_fp)

predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Maps a label index in the 'span_label' vocabulary namespace to its name.
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

# Fixed seed: the shuffled processing order is reproducible across runs.
random.seed(42)
random.shuffle(corpus)

# Append mode is kept on purpose: re-running the script adds to an existing
# DST file rather than truncating it. The `with` block guarantees the
# buffered records are flushed and the file closed even if decoding raises.
with open(args.dst, 'a', encoding='utf-8') as output_fp:
    for line in tqdm(corpus):
        tokens, ann = line['tokens'], line['annotation']
        start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']

        # A single span is passed in, so only the first result is used.
        # NOTE(review): presumably a per-label probability tensor for that
        # span -- confirm against SpanPredictor.force_decode.
        prob_dist = predictor.force_decode(tokens, [(start, end)])[0]

        # Indices of the k highest-scoring labels, best first.
        topk_indices = prob_dist.argsort(descending=True)[:k]
        prob = prob_dist[topk_indices].tolist()
        frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]

        output_fp.write(json.dumps({
            'tokens': tokens,
            'frames': frames,
            'kairos': kairos_label
        }) + '\n')