# sociolome/scripts/archive/predict_kairos.py
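"""Run the sftp SpanPredictor over a folder of KAIROS XML documents and map the
predicted frame labels to KAIROS event types.

Hypothetical command line (the positional arguments follow the parser defined
in run(); the file names are placeholders):

    python predict_kairos.py model.tar.gz ./xml_docs/ ontology_map.json out/predictions.jsonl
"""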
import os
import argparse
from xml.etree import ElementTree
import copy
from operator import attrgetter
import json
import logging
from sftp import SpanPredictor


def predict_kairos(model_archive, source_folder, onto_map):
    """Run the span finder over every XML file under source_folder and map the
    predicted frame labels to KAIROS event types using the onto_map JSON."""
    # Collect all XML files under the source folder (recursively).
    xml_files = list()
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'Found {len(xml_files)} XML files:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    # Build a frame-label -> KAIROS-event mapping from the ontology JSON.
    logging.info('Loading ontology from ' + onto_map)
    k_map = dict()
    for kairos_event, content in json.load(open(onto_map)).items():
        for fr in content['framenet']:
            if fr['label'] in k_map:
                logging.info('Duplicate frame: ' + fr['label'])
            k_map[fr['label']] = kairos_event
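
    # The mapping built above assumes the ontology JSON is shaped roughly like
    # this sketch (event and frame names here are made-up placeholders, not
    # taken from the actual KAIROS ontology):
    #
    #   {
    #       "Conflict.Attack": {
    #           "framenet": [{"label": "Attack"}, {"label": "Invading"}]
    #       }
    #   }
    #
    # i.e. top-level keys are KAIROS event types and each entry lists the
    # FrameNet frame labels that should be rewritten to that event type.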

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = list()
    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            # The first child of each DOC element holds the text segments.
            text = list(doc)[0]
            for seg in text:
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                # Sort tokens by character offset; the attribute value is a
                # string, so compare it numerically.
                tokens = [child for child in seg if child.tag == 'TOKEN']
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta
                # Keep only frames that map to a KAIROS event type, rewriting
                # their labels accordingly.
                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames
                predictions.append(one_pred)
    logging.info('Finished prediction.')
    return predictions
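

# Each record returned by predict_kairos (and written to the output JSONL by
# run() below) is assumed to look roughly like the sketch here; this is based
# on how the dict is assembled above, not on SpanPredictor's documented output:
#
#   {
#       "meta": {...DOC attributes..., "seg": {...SEG attributes...}},
#       "prediction": [{"label": "<KAIROS event type>", ...span fields...}, ...]
#   }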


def do_task(input_dir, model_archive, onto_map):
    """
    This function is called by the KAIROS infrastructure code for each
    TASK1 input.
    """
    return predict_kairos(model_archive=model_archive,
                          source_folder=input_dir,
                          onto_map=onto_map)
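
# Hypothetical example call (the paths are placeholders, not part of any real
# KAIROS setup):
#
#   records = do_task('/path/to/task1/docs',
#                     '/path/to/span_finder_model.tar.gz',
#                     '/path/to/ontology_map.json')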


def run():
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path (JSONL file).')
    args = parser.parse_args()
    logging.basicConfig(level='INFO', format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s')

    predictions = predict_kairos(model_archive=args.model_archive,
                                 source_folder=args.source_folder,
                                 onto_map=args.onto_map)

    logging.info('Saving to ' + args.destination + ' ...')
    # Create the output directory only if the destination path includes one.
    dest_dir = os.path.dirname(args.destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))
    logging.info('Done.')


if __name__ == '__main__':
    run()