import argparse
import copy
import json
import logging
import os
from operator import attrgetter
from xml.etree import ElementTree

from sftp import SpanPredictor


def predict_kairos(model_archive, source_folder, onto_map):
    # Collect all XML files under the source folder (recursively).
    xml_files = list()
    for root, _, files in os.walk(source_folder):
        for f in files:
            if f.endswith('.xml'):
                xml_files.append(os.path.join(root, f))
    logging.info(f'Found {len(xml_files)} XML files:')
    for fn in xml_files:
        logging.info(' - ' + fn)

    # Build a FrameNet-label -> KAIROS-event mapping from the ontology JSON.
    logging.info('Loading ontology from ' + onto_map)
    k_map = dict()
    with open(onto_map) as onto_fp:
        for kairos_event, content in json.load(onto_fp).items():
            for fr in content['framenet']:
                if fr['label'] in k_map:
                    logging.warning('Duplicate frame: ' + fr['label'])
                k_map[fr['label']] = kairos_event

    logging.info('Loading model from ' + model_archive + ' ...')
    predictor = SpanPredictor.from_path(model_archive)

    predictions = list()
    for fn in xml_files:
        logging.info('Now processing ' + os.path.basename(fn))
        tree = ElementTree.parse(fn).getroot()
        for doc in tree:
            doc_meta = copy.deepcopy(doc.attrib)
            text = list(doc)[0]
            for seg in text:
                # Carry document- and segment-level attributes along with each prediction.
                seg_meta = copy.deepcopy(doc_meta)
                seg_meta['seg'] = copy.deepcopy(seg.attrib)
                tokens = [child for child in seg if child.tag == 'TOKEN']
                # Sort by character offset numerically, not lexicographically.
                tokens.sort(key=lambda t: int(t.attrib['start_char']))
                words = list(map(attrgetter('text'), tokens))
                one_pred = predictor.predict_sentence(words)
                one_pred['meta'] = seg_meta
                # Keep only frames that map to the KAIROS ontology, relabeled as KAIROS events.
                new_frames = list()
                for fr in one_pred['prediction']:
                    if fr['label'] in k_map:
                        fr['label'] = k_map[fr['label']]
                        new_frames.append(fr)
                one_pred['prediction'] = new_frames
                predictions.append(one_pred)

    logging.info('Finished prediction.')
    return predictions


def do_task(input_dir, model_archive, onto_map):
    """
    This function is called by the KAIROS infrastructure code for each TASK1 input.
    """
    return predict_kairos(model_archive=model_archive, source_folder=input_dir, onto_map=onto_map)


def run():
    parser = argparse.ArgumentParser(description='Span Finder for KAIROS Quizlet4\n')
    parser.add_argument('model_archive', metavar='MODEL_ARCHIVE', type=str, help='Path to the model archive file.')
    parser.add_argument('source_folder', metavar='SOURCE_FOLDER', type=str, help='Path to the folder that contains the XMLs.')
    parser.add_argument('onto_map', metavar='ONTO_MAP', type=str, help='Path to the ontology JSON.')
    parser.add_argument('destination', metavar='DESTINATION', type=str, help='Output path (JSONL file).')
    args = parser.parse_args()
    logging.basicConfig(level='INFO', format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s')

    predictions = predict_kairos(
        model_archive=args.model_archive, source_folder=args.source_folder, onto_map=args.onto_map
    )

    # Write one JSON object per line (JSONL). Fall back to the current directory
    # if the destination path has no directory component.
    logging.info('Saving to ' + args.destination + ' ...')
    os.makedirs(os.path.dirname(args.destination) or '.', exist_ok=True)
    with open(args.destination, 'w') as fp:
        fp.write('\n'.join(map(json.dumps, predictions)))
    logging.info('Done.')


if __name__ == '__main__':
    run()