from argparse import ArgumentParser
from typing import *
import json
import logging

from sftp import SpanPredictor

logger = logging.getLogger('ConcretePredictor')


def read_kairos(ontology_mapping_path: Optional[str] = None) -> Optional[Dict[str, str]]:
    """Build a FrameNet-label -> KAIROS-label lookup from a legacy mapping file.

    Legacy: this only understands the old mapping-file layout. The expected
    JSON layout is ``{kairos_label: {"framenet": [{"label": fn_label}, ...]}}``.

    If a FrameNet label appears under more than one KAIROS label, a warning is
    logged. The KAIROS label that is read last wins.

    Args:
        ontology_mapping_path: Path to the JSON mapping file, or None.

    Returns:
        A dict mapping each FrameNet label to its KAIROS label, or None when no
        path is given.
    """
    if ontology_mapping_path is None:
        return None
    # JSON text is UTF-8 per RFC 8259. The context manager closes the handle,
    # which the previous ``json.load(open(...))`` form leaked.
    with open(ontology_mapping_path, encoding='utf-8') as f:
        raw = json.load(f)
    fn2kairos: Dict[str, str] = dict()
    for kairos_label in raw:
        for fn in raw[kairos_label]['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning(f'"{fn_label}" is repeated in the ontology file.')
            fn2kairos[fn_label] = kairos_label
    return fn2kairos


def run(src, dst, model_path, ontology_mapping_path, device):
    """Load a SpanPredictor and run ``predict_concrete`` from ``src`` to ``dst``.

    Args:
        src: Passed as the first (input) argument to
            ``SpanPredictor.predict_concrete``.
        dst: Passed as the second (output) argument to
            ``SpanPredictor.predict_concrete``.
        model_path: Passed to ``SpanPredictor.from_path`` to load the model.
        ontology_mapping_path: Passed to
            ``SpanPredictor.read_ontology_mapping``. It may be None when the CLI
            ``--map`` flag is omitted. NOTE(review): confirm that
            ``read_ontology_mapping`` accepts None.
        device: Passed to ``SpanPredictor.from_path`` as ``cuda_device``. The
            CLI default is -1, which is presumably CPU; confirm against sftp.
    """
    mapping = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    predictor = SpanPredictor.from_path(model_path, cuda_device=device)
    predictor.predict_concrete(src, dst, ontology_mapping=mapping)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('src', type=str, help='input path passed to SpanPredictor.predict_concrete')
    parser.add_argument('dst', type=str, help='output path passed to SpanPredictor.predict_concrete')
    parser.add_argument('model', type=str, help='model path passed to SpanPredictor.from_path')
    parser.add_argument('--map', type=str, default=None,
                        help='ontology mapping file passed to SpanPredictor.read_ontology_mapping')
    parser.add_argument('--device', type=int, default=-1,
                        help='value passed to SpanPredictor.from_path as cuda_device (default: %(default)s)')
    args = parser.parse_args()
    run(args.src, args.dst, args.model, args.map, args.device)