""" Adapted from comm2multilabel.py from the Bert-for-FrameNet project (https://gitlab.com/gosseminnema/bert-for-framenet) """ import dataclasses import json import os import glob import sys from collections import defaultdict from typing import List, Optional import nltk from concrete import Communication from concrete.util import read_communication_from_file, lun, get_tokens @dataclasses.dataclass class FrameAnnotation: tokens: List[str] = dataclasses.field(default_factory=list) pos: List[str] = dataclasses.field(default_factory=list) @dataclasses.dataclass class MultiLabelAnnotation(FrameAnnotation): frame_list: List[List[str]] = dataclasses.field(default_factory=list) lu_list: List[Optional[str]] = dataclasses.field(default_factory=list) def to_txt(self): for i, tok in enumerate(self.tokens): yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}" @staticmethod def from_txt(sentence_lines): tokens = [] pos = [] frame_list = [] lu_list = [] for line in sentence_lines: # ignore any spaces if line.startswith(" "): continue columns = line.split() tokens.append(columns[0]) pos.append(columns[1]) # read frame list, handle empty lists if columns[2] == "_": frame_list.append([]) else: frame_list.append(columns[2].split("|")) # read lu list, handle nulls if columns[3] == "_": lu_list.append(None) else: lu_list.append(columns[3]) return MultiLabelAnnotation(tokens, pos, frame_list, lu_list) def get_label_set(self): label_set = set() for tok_labels in self.frame_list: for label in tok_labels: label_set.add(label) return label_set def convert_file(file, language="english", confidence_filter=0.0): print("Reading input file...") comm = read_communication_from_file(file) print("Mapping sentences to situations...") tok_uuid_to_situation = map_sent_to_situation(comm) print("# sentences with situations:", len(tok_uuid_to_situation)) for section in lun(comm.sectionList): for sentence in lun(section.sentenceList): tokens = get_tokens(sentence.tokenization) situations = tok_uuid_to_situation[sentence.tokenization.uuid.uuidString] tok_to_annos = map_tokens_to_annotations(comm, situations, confidence_filter) frame_list, tok_list = prepare_ml_lists(language, tok_to_annos, tokens) ml_anno = MultiLabelAnnotation(tok_list, ["_" for _ in tok_list], frame_list, [None for _ in tok_list]) yield ml_anno def prepare_ml_lists(language, tok_to_annos, tokens): tok_list = [] frame_list = [] for tok_idx, tok in enumerate(tokens): # split tokens that include punctuation split_tok = nltk.word_tokenize(tok.text, language=language) tok_list.extend(split_tok) tok_anno = [] for anno in tok_to_annos.get(tok_idx, []): tok_anno.append(anno) frame_list.extend([list(tok_anno) for _ in split_tok]) # remove annotations from final punctuation & solve BIO weird stuff for idx, (tok, frame_annos) in enumerate(zip(tok_list, frame_list)): if tok in ",.:;\"'`«»": to_delete = [] for fa in frame_annos: if fa.startswith("T:"): compare_fa = fa else: compare_fa = "I" + fa[1:] if idx == len(tok_list) - 1: to_delete.append(fa) elif compare_fa not in frame_list[idx + 1]: to_delete.append(fa) for fa in to_delete: frame_annos.remove(fa) for fa_idx, fa in enumerate(frame_annos): if fa.startswith("B:"): # check if we had exactly the same label the token before if idx > 0 and fa in frame_list[idx - 1]: frame_annos[fa_idx] = "I" + fa[1:] return frame_list, tok_list def map_tokens_to_annotations(comm: Communication, situations: List[str], confidence_filter: float): tok_to_annos = defaultdict(list) for sit_idx, sit_uuid in enumerate(situations): situation = comm.situationMentionForUUID[sit_uuid] if situation.confidence < confidence_filter: continue frame_type = situation.situationKind tgt_tokens = situation.tokens.tokenIndexList if frame_type == "@@VIRTUAL_ROOT@@": continue for tok_id in tgt_tokens: tok_to_annos[tok_id].append(f"T:{frame_type}@{sit_idx:02}@@{situation.confidence}") for arg in situation.argumentList: if arg.confidence < confidence_filter: continue fe_type = arg.role fe_tokens = arg.entityMention.tokens.tokenIndexList for tok_n, tok_id in enumerate(fe_tokens): if tok_n == 0: bio = "B" else: bio = "I" tok_to_annos[tok_id].append(f"{bio}:{frame_type}:{fe_type}@{sit_idx:02}@@{arg.confidence}") return tok_to_annos def map_sent_to_situation(comm): tok_uuid_to_situation = defaultdict(list) for situation in comm.situationMentionSetList: for mention in situation.mentionList: tok_uuid_to_situation[mention.tokens.tokenizationId.uuidString].append(mention.uuid.uuidString) return tok_uuid_to_situation def main(): file_in = sys.argv[1] language = sys.argv[2] output_directory = sys.argv[3] confidence_filter = float(sys.argv[4]) split_by_migration_files = False file_in_base = os.path.basename(file_in) file_out = f"{output_directory}/lome_{file_in_base}" multi_label_annos = list(convert_file(file_in, language=language, confidence_filter=confidence_filter)) multi_label_json = [dataclasses.asdict(anno) for anno in multi_label_annos] if split_by_migration_files: files = glob.glob("output/migration/split_data/split_dev10_sep_txt_files/*.orig.txt") files.sort(key=lambda f: int(f.split("/")[-1].rstrip(".orig.txt"))) for anno, file in zip(multi_label_annos, files): basename = file.split("/")[-1].rstrip(".orig.txt") spl_file_out = f"{output_directory}/{basename}" with open(f"{spl_file_out}.txt", "w", encoding="utf-8") as f_txt: for line in anno.to_txt(): f_txt.write(line + os.linesep) f_txt.write(os.linesep) else: print(file_out) with open(f"{file_out}.json", "w", encoding="utf-8") as f_json: json.dump(multi_label_json, f_json, indent=4) with open(f"{file_out}.txt", "w", encoding="utf-8") as f_txt: for anno in multi_label_annos: for line in anno.to_txt(): f_txt.write(line + os.linesep) f_txt.write(os.linesep) if __name__ == '__main__': main()