Spaces:
Sleeping
Sleeping
import json | |
import os | |
import copy | |
from collections import defaultdict | |
from argparse import ArgumentParser | |
from tqdm import tqdm | |
import random | |
from tqdm import tqdm | |
from scripts.predict_concrete import read_kairos | |
from sftp import SpanPredictor | |
# Command-line interface: input corpus path, model path, output path, the
# number of top-scoring labels to keep per span, and the device passed to the
# predictor as `cuda_device`.
parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
parser.add_argument('dst', type=str)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()
k = args.topk

# Load the whole corpus up front. The context manager closes the input file as
# soon as parsing finishes instead of leaking the handle for the whole run.
with open(args.aida) as corpus_fp:
    corpus = json.load(corpus_fp)

predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Index -> label string lookup for the 'span_label' vocabulary namespace.
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

# Seed before shuffling so the in-place shuffle of the corpus is reproducible.
random.seed(42)
random.shuffle(corpus)
# Write one JSON line per corpus entry. The file is opened in append mode, so
# rerunning against an existing `dst` adds to it rather than overwriting it.
# The context manager guarantees buffered lines are flushed and the handle is
# closed, including when prediction raises partway through the corpus.
with open(args.dst, 'a') as output_fp:
    for line in tqdm(corpus):
        tokens, ann = line['tokens'], line['annotation']
        start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
        # Score only the annotated span; [0] picks the result for that one span.
        # NOTE(review): presumably a 1-D torch tensor of per-label scores,
        # given the `argsort(descending=True)` call below -- confirm.
        prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
        # Indices of the k highest-scoring labels, best first.
        topk_indices = prob_dist.argsort(descending=True)[:k]
        prob = prob_dist[topk_indices].tolist()
        # Pair each label string with its score; json.dumps emits the tuples
        # as two-element arrays.
        frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
        output_fp.write(json.dumps({
            'tokens': tokens,
            'frames': frames,
            'kairos': kairos_label
        }) + '\n')