Spaces:
Sleeping
Sleeping
File size: 1,369 Bytes
05922fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from argparse import ArgumentParser
from typing import *
import json
import logging
from sftp import SpanPredictor
# Module-level logger shared by the functions below; it is registered under
# the fixed name 'ConcretePredictor' rather than the module's __name__.
logger = logging.getLogger('ConcretePredictor')
def read_kairos(ontology_mapping_path: Optional[str] = None) -> Optional[Dict[str, str]]:
    """Build a FrameNet-label to KAIROS-label lookup from a mapping file.

    Legacy. For the old mapping file only.

    The file is read as JSON. Its top-level keys are treated as KAIROS
    labels; each value must contain a ``'framenet'`` list whose items each
    carry a ``'label'`` entry.

    Args:
        ontology_mapping_path: Path to the JSON mapping file, or ``None``.

    Returns:
        ``None`` when no path is given. Otherwise a dict mapping every
        FrameNet label to the KAIROS label it is listed under. If a FrameNet
        label occurs more than once, a warning is logged and the last
        occurrence wins.
    """
    if ontology_mapping_path is None:
        return None

    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector. JSON is UTF-8 by
    # specification, so do not depend on the platform default encoding.
    with open(ontology_mapping_path, encoding='utf-8') as mapping_file:
        raw = json.load(mapping_file)

    fn2kairos = {}
    for kairos_label, entry in raw.items():
        for fn in entry['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning('"%s" is repeated in the ontology file.', fn_label)
            # Assigned unconditionally: a repeated label is overwritten.
            fn2kairos[fn_label] = kairos_label
    return fn2kairos
def run(src, dst, model_path, ontology_mapping_path, device):
    """Load a span predictor and run its Concrete prediction step.

    Args:
        src: Forwarded as the first argument of ``predict_concrete``
            (presumably the input location; confirm against SpanPredictor).
        dst: Forwarded as the second argument of ``predict_concrete``
            (presumably the output location; confirm against SpanPredictor).
        model_path: Passed to ``SpanPredictor.from_path`` to load the model.
        ontology_mapping_path: Passed to
            ``SpanPredictor.read_ontology_mapping``; the result is handed to
            ``predict_concrete`` as ``ontology_mapping``.
        device: Passed to ``SpanPredictor.from_path`` as ``cuda_device``.
    """
    # The mapping is read before the model is loaded, as in the original order.
    ontology_mapping = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    span_predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    span_predictor.predict_concrete(src, dst, ontology_mapping=ontology_mapping)
if __name__ == '__main__':
    # Command-line entry point: three required positionals (src, dst, model)
    # plus the optional --map path and --device index (defaults to -1).
    cli = ArgumentParser()
    for positional_name in ('src', 'dst', 'model'):
        cli.add_argument(positional_name, type=str)
    cli.add_argument('--map', type=str, default=None)
    cli.add_argument('--device', type=int, default=-1)
    options = cli.parse_args()
    run(
        src=options.src,
        dst=options.dst,
        model_path=options.model,
        ontology_mapping_path=options.map,
        device=options.device,
    )
|