# sociolome/scripts/archive/predict_kairos.py
# Author: Gosse Minnema — initial commit (05922fb)
import os
import argparse
from xml.etree import ElementTree
import copy
from operator import attrgetter
import json
import logging
from sftp import SpanPredictor
def predict_kairos(model_archive, source_folder, onto_map):
    """Run span-finder predictions over a folder of KAIROS XML files.

    Recursively collects ``.xml`` files under ``source_folder``, builds a
    FrameNet-frame -> KAIROS-event mapping from the ontology JSON, then runs
    the SFTP span predictor on the tokens of every text segment.

    Args:
        model_archive: Path to the trained ``SpanPredictor`` model archive.
        source_folder: Folder searched recursively for XML input files.
        onto_map: Path to a JSON file mapping KAIROS event types to lists of
            FrameNet frames (under each entry's ``'framenet'`` key).

    Returns:
        A list of prediction dicts, one per text segment; each dict carries a
        ``'meta'`` key with the document and segment attributes, and its
        ``'prediction'`` list is filtered to frames present in the ontology,
        relabeled with the corresponding KAIROS event type.
    """
    xml_files = []
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'{len(xml_files)} files are found:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    logging.info('Loading ontology from ' + onto_map)
    # Invert the ontology: FrameNet frame label -> KAIROS event type.
    k_map = dict()
    # FIX: open the ontology in a context manager instead of a bare
    # json.load(open(...)), which leaks the file handle.
    with open(onto_map) as onto_fp:
        onto = json.load(onto_fp)
    for kairos_event, content in onto.items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info("Duplicate frame: " + fr['label'])
            k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = []
    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # First child of each doc element holds the segments — assumes the
            # LDC-style <DOC><TEXT><SEG>... layout; TODO confirm schema.
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # BUG FIX: XML attribute values are strings, so sorting on the
                # raw 'start_char' put "10" before "2". Sort numerically.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta
                # Keep only frames known to the ontology, relabeled with the
                # matching KAIROS event type.
                new_frames = []
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames
                predictions.append(one_pred)
    logging.info('Finished Prediction.')
    return predictions
def do_task(input_dir, model_archive, onto_map):
    """Entry point invoked by the KAIROS infrastructure for each TASK1 input.

    Thin adapter that forwards its arguments to :func:`predict_kairos`
    (note ``input_dir`` maps onto ``source_folder``) and returns its result.
    """
    return predict_kairos(
        model_archive=model_archive,
        source_folder=input_dir,
        onto_map=onto_map,
    )
def run():
    """CLI entry point: parse arguments, run predictions, write a JSONL file.

    Expects four positional arguments (model archive, source folder, ontology
    JSON, output path) and writes one JSON object per line to the destination.
    """
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path. (jsonl file path)')
    args = parser.parse_args()
    logging.basicConfig(level='INFO', format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)
    logging.info('Saving to ' + args.destination + ' ...')
    # BUG FIX: os.path.dirname() returns '' for a bare filename, and
    # os.makedirs('') raises FileNotFoundError — only create the directory
    # when the destination actually names one.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    with open(args.destination, 'w') as fp:
        # JSONL: one object per line, newline-terminated so the file plays
        # nicely with line-oriented tools (the original omitted the final
        # newline).
        for one_pred in predictions:
            fp.write(json.dumps(one_pred) + '\n')
    logging.info('Done.')


if __name__ == '__main__':
    run()