sociolome / scripts /predict_concrete.py
Gosse Minnema
Initial commit
05922fb
from argparse import ArgumentParser
from typing import *
import json
import logging
from sftp import SpanPredictor
# Named module logger; used by the legacy mapping reader to warn about
# FrameNet labels that appear more than once in the mapping file.
logger = logging.getLogger(name='ConcretePredictor')
def read_kairos(ontology_mapping_path: Optional[str] = None) -> Optional[Dict[str, str]]:
    """Build a FrameNet-label -> KAIROS-label lookup from a legacy mapping file.

    Legacy helper: it only understands the old mapping-file layout. That is a
    JSON object keyed by KAIROS label, where each value holds a ``'framenet'``
    list of entries that each carry a ``'label'`` field.

    Args:
        ontology_mapping_path: Path to the legacy JSON mapping file, or ``None``.

    Returns:
        ``None`` when no path is given. Otherwise, a dict mapping each FrameNet
        label to its KAIROS label. If a FrameNet label is listed under more
        than one KAIROS label, a warning is logged and the last occurrence wins.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an entry lacks a ``'framenet'`` or ``'label'`` key.
    """
    if ontology_mapping_path is None:
        return None
    # Use a context manager so the handle is closed even when parsing fails.
    # JSON text is UTF-8, so don't depend on the platform's default encoding.
    with open(ontology_mapping_path, encoding='utf-8') as f:
        raw = json.load(f)
    fn2kairos = {}
    for kairos_label, entry in raw.items():
        for fn in entry['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning(f'"{fn_label}" is repeated in the ontology file.')
            # Later entries overwrite earlier ones for a repeated label.
            fn2kairos[fn_label] = kairos_label
    return fn2kairos
def run(src, dst, model_path, ontology_mapping_path, device):
    """Load a SpanPredictor and run ``predict_concrete`` from ``src`` to ``dst``.

    Args:
        src: Input, forwarded unchanged to ``SpanPredictor.predict_concrete``.
        dst: Output destination, forwarded unchanged to
            ``SpanPredictor.predict_concrete``.
        model_path: Model location handed to ``SpanPredictor.from_path``.
        ontology_mapping_path: Mapping file handed to
            ``SpanPredictor.read_ontology_mapping``. It may be ``None``; how
            that is treated is decided by ``read_ontology_mapping``.
        device: Forwarded as ``cuda_device`` to ``SpanPredictor.from_path``.
    """
    # Resolve the mapping before loading the model, so a bad mapping file
    # presumably fails before the (likely costlier) model load.
    ontology = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    SpanPredictor.from_path(
        model_path, cuda_device=device
    ).predict_concrete(src, dst, ontology_mapping=ontology)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them to run().
    parser = ArgumentParser(
        description='Run a trained SpanPredictor over SRC and write the '
                    'predictions to DST.'
    )
    parser.add_argument(
        'src', type=str,
        help='input passed to SpanPredictor.predict_concrete',
    )
    parser.add_argument(
        'dst', type=str,
        help='output destination passed to SpanPredictor.predict_concrete',
    )
    parser.add_argument(
        'model', type=str,
        help='model path loaded with SpanPredictor.from_path',
    )
    parser.add_argument(
        '--map', type=str, default=None,
        help='optional ontology mapping file, read with '
             'SpanPredictor.read_ontology_mapping (default: %(default)s)',
    )
    parser.add_argument(
        '--device', type=int, default=-1,
        help='value forwarded as cuda_device to SpanPredictor.from_path '
             '(default: %(default)s)',
    )
    args = parser.parse_args()
    run(args.src, args.dst, args.model, args.map, args.device)