import json
import os
import copy
import random
from argparse import ArgumentParser
from collections import defaultdict

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor


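# Positional arguments: the AIDA corpus (JSON), the trained SFTP model, and the output
# path; --topk sets how many candidate frames to keep and --device selects the GPU.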
parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
parser.add_argument('dst', type=str)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

k = args.topk
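# Load the AIDA corpus, build the span predictor, and recover the mapping from
# label indices to frame names from the model's 'span_label' vocabulary.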
with open(args.aida) as corpus_fp:
    corpus = json.load(corpus_fp)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
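# Shuffle with a fixed seed so repeated runs process the corpus in the same order.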
random.seed(42)
random.shuffle(corpus)


# Append one JSON line per corpus entry: the tokens, the top-k candidate frames with
# their probabilities, and the original KAIROS label from the annotation.
with open(args.dst, 'a') as output_fp:
    for line in tqdm(corpus):
        tokens, ann = line['tokens'], line['annotation']
        start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
        # Force-decode the annotated span to get a probability distribution over frame labels.
        prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
        # Keep the k most probable labels and map their indices back to frame names.
        topk_indices = prob_dist.argsort(descending=True)[:k]
        prob = prob_dist[topk_indices].tolist()
        frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
        output_fp.write(json.dumps({
            'tokens': tokens,
            'frames': frames,
            'kairos': kairos_label
        }) + '\n')