File size: 1,369 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from argparse import ArgumentParser
from typing import *
import json
import logging

from sftp import SpanPredictor

# Module logger; read_kairos uses it to warn about repeated FrameNet labels.
logger = logging.getLogger('ConcretePredictor')


def read_kairos(ontology_mapping_path: Optional[str] = None) -> Optional[Dict[str, str]]:
    """Build a FrameNet-label -> KAIROS-label lookup from a legacy mapping file.

    Legacy helper for the old mapping-file format only; ``run`` in this file
    reads its mapping through ``SpanPredictor.read_ontology_mapping`` instead.

    Expected file layout (as read below): a top-level JSON object whose keys
    are KAIROS labels; each value is an object with a ``"framenet"`` list whose
    items are objects carrying a ``"label"`` string.

    Args:
        ontology_mapping_path: Path to the JSON mapping file, or None.

    Returns:
        None when ``ontology_mapping_path`` is None. Otherwise a dict mapping
        each FrameNet label to its KAIROS label. If a FrameNet label appears
        under several KAIROS labels, a warning is logged and the last
        occurrence (in the JSON object's key order) wins.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if an entry lacks ``"framenet"`` or an item lacks ``"label"``.
    """
    if ontology_mapping_path is None:
        return None
    # Binary mode lets json detect UTF-8/16/32 (including a UTF-8 BOM) itself
    # instead of relying on the platform's default text encoding; the `with`
    # closes the handle deterministically rather than leaking it until GC.
    with open(ontology_mapping_path, 'rb') as f:
        raw = json.load(f)
    fn2kairos: Dict[str, str] = {}
    for kairos_label, entry in raw.items():
        for fn in entry['framenet']:
            fn_label = fn['label']
            if fn_label in fn2kairos:
                logger.warning(f'"{fn_label}" is repeated in the ontology file.')
            fn2kairos[fn_label] = kairos_label
    return fn2kairos


def run(src, dst, model_path, ontology_mapping_path, device):
    """Load a span model and run ``predict_concrete`` from ``src`` to ``dst``.

    Args:
        src: Input location, forwarded unchanged to ``predict_concrete``.
        dst: Output location, forwarded unchanged to ``predict_concrete``.
        model_path: Model location handed to ``SpanPredictor.from_path``.
        ontology_mapping_path: Mapping file location, or None; forwarded to
            ``SpanPredictor.read_ontology_mapping`` (how None is treated is
            decided by that method -- verify there).
        device: Forwarded as ``cuda_device`` to ``SpanPredictor.from_path``.
    """
    # The ontology mapping is resolved before the model is loaded.
    ontology = SpanPredictor.read_ontology_mapping(ontology_mapping_path)
    span_model = SpanPredictor.from_path(model_path, cuda_device=device)
    span_model.predict_concrete(src, dst, ontology_mapping=ontology)


if __name__ == '__main__':
    # Command-line entry point: parse arguments and hand them to run().
    parser = ArgumentParser(
        description='Run a trained SpanPredictor (predict_concrete) on SRC '
                    'and write the result to DST.',
    )
    parser.add_argument('src', type=str,
                        help='Input path passed to SpanPredictor.predict_concrete.')
    parser.add_argument('dst', type=str,
                        help='Output path passed to SpanPredictor.predict_concrete.')
    parser.add_argument('model', type=str,
                        help='Path of the trained model, loaded with SpanPredictor.from_path.')
    parser.add_argument('--map', type=str, default=None,
                        help='Optional ontology mapping file, read with '
                             'SpanPredictor.read_ontology_mapping (default: none).')
    parser.add_argument('--device', type=int, default=-1,
                        help='Value passed as cuda_device to SpanPredictor.from_path '
                             '(default: -1).')
    args = parser.parse_args()
    run(args.src, args.dst, args.model, args.map, args.device)