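"""
Force-decode AIDA event spans with an SFTP SpanPredictor and record the
top-k frame predictions next to each span's original KAIROS label.

Input  (aida): JSON list of examples, each with 'tokens' and an 'annotation'
               containing 'start_idx', 'end_idx', and 'label'.
Output (dst):  JSON lines, one object per span, with 'tokens', 'frames'
               (a list of (frame, probability) pairs), and 'kairos'.

Example invocation (script and file names are illustrative):
    python force_decode_aida.py aida.json model.tar.gz out.jsonl --topk 10 --device 0
"""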
import json
import random
from argparse import ArgumentParser

from tqdm import tqdm

from sftp import SpanPredictor
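# Command-line arguments: input corpus, model archive, and output path,
# plus the number of frames to keep per span and the CUDA device to run on.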
parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
parser.add_argument('dst', type=str)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()
k = args.topk
corpus = json.load(open(args.aida))
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
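# Map indices in the span-label vocabulary back to frame names.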
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
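# Shuffle the corpus with a fixed seed so the processing order is reproducible.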
random.seed(42)
random.shuffle(corpus)
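# Append mode: predictions already written to dst are kept rather than overwritten.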
output_fp = open(args.dst, 'a')
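# Force-decode each annotated span, keep the k most probable frame labels,
# and write them out alongside the original KAIROS label.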
for line in tqdm(corpus):
    tokens, ann = line['tokens'], line['annotation']
    start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
    prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
    topk_indices = prob_dist.argsort(descending=True)[:k]
    prob = prob_dist[topk_indices].tolist()
    frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
    output_fp.write(json.dumps({
        'tokens': tokens,
        'frames': frames,
        'kairos': kairos_label
    }) + '\n')