File size: 3,565 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import argparse
from xml.etree import ElementTree
import copy
from operator import attrgetter
import json
import logging

from sftp import SpanPredictor


def predict_kairos(model_archive, source_folder, onto_map):
    """Run span prediction over every XML file found under *source_folder*.

    Args:
        model_archive: Path to the serialized SpanPredictor model archive.
        source_folder: Directory tree searched recursively for ``.xml`` files.
        onto_map: Path to a JSON file mapping KAIROS event types to FrameNet
            frames; used to filter and relabel predicted frames.

    Returns:
        list[dict]: One prediction dict per text segment.  Each carries
        ``meta`` (copied document + segment attributes) and ``prediction``
        (only frames whose FrameNet label appears in the ontology map,
        relabelled to the KAIROS event type; unmapped frames are dropped).
    """
    xml_files = [
        os.path.join(root, f)
        for root, _, files in os.walk(source_folder)
        for f in files
        if f.endswith('.xml')
    ]
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info('  - ' + fn)

    logging.info('Loading ontology from ' + onto_map)
    # Invert the ontology: FrameNet frame label -> KAIROS event type.
    k_map = dict()
    # FIX: close the ontology file (original leaked the handle from
    # json.load(open(onto_map))).
    with open(onto_map) as onto_fp:
        onto = json.load(onto_fp)
    for kairos_event, content in onto.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = list()

    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # First child of the document element holds the segments
            # (LTF-style layout: DOC -> TEXT -> SEG -> TOKEN).
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # FIX: XML attribute values are strings, so the original
                # lexicographic sort put e.g. '10' before '2'; compare as ints
                # to restore true character order.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = [tok.text for tok in tokens]
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta

                # Keep only frames present in the ontology map, relabelled to
                # their KAIROS event type.
                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames

                predictions.append(one_pred)

    logging.info('Finished Prediction.')

    return predictions


def do_task(input_dir, model_archive, onto_map):
    """Entry point invoked by the KAIROS infrastructure for each TASK1 input.

    Thin wrapper: forwards its arguments to ``predict_kairos`` and returns
    that function's list of segment-level predictions unchanged.
    """
    return predict_kairos(
        model_archive=model_archive,
        source_folder=input_dir,
        onto_map=onto_map,
    )


def run():
    """Command-line entry point: parse arguments, predict, write JSONL output.

    Expects four positional arguments (model archive, source folder, ontology
    JSON, destination path) and writes one JSON object per line to the
    destination file.
    """
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path. (jsonl file path)')
    args = parser.parse_args()

    logging.basicConfig(level='INFO', format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")

    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)

    logging.info('Saving to ' + args.destination + ' ...')
    # FIX: os.makedirs('') raises FileNotFoundError when the destination is a
    # bare filename with no directory component; only create when non-empty.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))

    logging.info('Done.')


# Script entry guard: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    run()