# Gosse Minnema
# Initial commit
# 05922fb
import json
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
from tqdm import tqdm
import random
from tqdm import tqdm
from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor
# Command-line interface: three positional paths plus two optional settings.
# Help strings describe how each value is consumed further down in this script.
parser = ArgumentParser(
    description=(
        'Force-decode annotated spans with a SpanPredictor and append the '
        'top-k span labels for each example to a JSON-lines file.'
    ),
)
parser.add_argument(
    'aida', type=str,
    help="path to a JSON file holding a list of examples, each with "
         "'tokens' and an 'annotation' (start_idx, end_idx, label)",
)
parser.add_argument(
    'model', type=str,
    help='path handed to SpanPredictor.from_path to load the model',
)
parser.add_argument(
    'dst', type=str,
    help='output file; one JSON object per example is appended to it',
)
parser.add_argument(
    '--topk', type=int, default=10,
    help='number of highest-scoring labels kept per span (default: %(default)s)',
)
parser.add_argument(
    '--device', type=int, default=0,
    help='CUDA device index passed to the predictor (default: %(default)s)',
)
args = parser.parse_args()
# Short alias used when slicing the ranked label indices.
k = args.topk
# Load the whole corpus up front. The context manager closes the input file
# as soon as parsing is done instead of leaking the handle; JSON is read as
# UTF-8 explicitly rather than relying on the platform default encoding.
with open(args.aida, encoding='utf-8') as corpus_fp:
    corpus = json.load(corpus_fp)

predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Index -> token mapping for the 'span_label' vocabulary namespace; used to
# turn predicted label indices back into label strings.
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

# Fixed seed so the in-place shuffle yields the same order on every run.
random.seed(42)
random.shuffle(corpus)
# Append mode: whatever is already in dst is kept. The context manager makes
# sure the file is flushed and closed even if decoding fails partway through.
with open(args.dst, 'a', encoding='utf-8') as output_fp:
    for line in tqdm(corpus):
        tokens, ann = line['tokens'], line['annotation']
        start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
        # Only the annotated span is scored; presumably force_decode returns
        # one entry per requested span, so [0] is the entry for this span
        # -- TODO confirm against SpanPredictor.force_decode.
        prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
        # NOTE(review): argsort(descending=True) suggests a torch tensor -- confirm.
        topk_indices = prob_dist.argsort(descending=True)[:k]
        prob = prob_dist[topk_indices].tolist()
        # Pair each of the k best label strings with its score.
        frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
        # One JSON object per line (JSON-lines layout).
        output_fp.write(json.dumps({
            'tokens': tokens,
            'frames': frames,
            'kairos': kairos_label,
        }) + '\n')