"""Dump frame-to-frame similarity tables derived from a trained span predictor.

The script loads a ``SpanPredictor`` archive, computes pairwise cosine
similarities between (a) the span-label embeddings and (b) the weight rows of
the last span-typing MLP layer, ranks FrameNet frames by those similarities,
and writes the result as JSON/TSV files under ``<DESTINATION>/emb`` and
``<DESTINATION>/mlp``.
"""
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
import json
import os

import nltk
import torch
from torch import nn

from sftp import SpanPredictor


def shift_grid_cos_sim(mat: torch.Tensor):
    """Return the pairwise cosine similarity of the rows of a 2-D tensor.

    Entry ``[i, j]`` of the result compares row ``i`` with row ``j`` of
    ``mat``. The raw cosine (range ``[-1, 1]``) is shifted and scaled to
    ``[0, 1]`` via ``(cos + 1) / 2``.
    """
    n_rows = mat.shape[0]
    # Build two (n, n, d) views so that position [i, j] pairs row j with row i.
    mat1 = mat.unsqueeze(0).expand(n_rows, -1, -1)
    mat2 = mat.unsqueeze(1).expand(-1, n_rows, -1)
    cos = nn.CosineSimilarity(2)
    sim = (cos(mat1, mat2) + 1) / 2
    return sim


def all_frames():
    """Download the ``framenet_v17`` NLTK package and return its frames."""
    nltk.download('framenet_v17')
    fn = nltk.corpus.framenet
    return fn.frames()


def extract_relations(fr):
    """List the frames related to ``fr`` together with their role.

    For every entry of ``fr.frameRelations`` both the ``subFrameName`` and the
    ``superFrameName`` fields are inspected; any name other than ``fr.name``
    is reported as ``(name, 'subFrame')`` or ``(name, 'superFrame')``.
    Duplicates are not removed.
    """
    ret = list()
    own_name = fr.name
    for rel in fr.frameRelations:
        for key in ['subFrameName', 'superFrameName']:
            rel_fr_name = rel[key]
            if rel_fr_name == own_name:
                # One side of every relation is the frame itself; skip it.
                continue
            # key[:-4] strips the trailing "Name" from the field name.
            ret.append((rel_fr_name, key[:-4]))
    return ret


def run():
    """Command-line entry point.

    Positional arguments: the predictor archive path, the destination folder
    and the path to a KAIROS JSON file. ``--topk`` (default 10) limits how
    many similar frames are written per frame in the TSV outputs.
    """
    parser = ArgumentParser()
    parser.add_argument('archive', metavar='ARCHIVE_PATH', type=str)
    parser.add_argument('dst', metavar='DESTINATION', type=str)
    parser.add_argument('kairos', metavar='KAIROS', type=str)
    parser.add_argument('--topk', metavar='TOPK', type=int, default=10)
    args = parser.parse_args()

    predictor = SpanPredictor.from_path(args.archive, cuda_device=-1)
    # Use a context manager so the input file is closed deterministically.
    with open(args.kairos) as fp:
        kairos_gold_mapping = json.load(fp)

    span_typing = predictor._model._span_typing
    label_emb = span_typing.label_emb.weight.clone().detach()
    idx2label = predictor._model.vocab.get_index_to_token_vocabulary('span_label')
    emb_sim = shift_grid_cos_sim(label_emb)

    # Load the frames once: every all_frames() call re-triggers nltk.download.
    frames = list(all_frames())
    fr2definition = {fr.name: (fr.URL, fr.definition) for fr in frames}

    last_mlp = span_typing.MLPs[-1].weight.detach().clone()
    mlp_sim = shift_grid_cos_sim(last_mlp)

    def rank_frame(sim):
        """Build, for each frame, its neighbours sorted by descending similarity.

        Returns a dict keyed by frame name with the keys ``similarity`` (list
        of ``(frame_name, score)``), ``ontology``, ``URL`` and ``definition``.
        Labels that are not frame names are ignored on both sides.
        """
        rank = sim.argsort(1, True)  # dim=1, descending=True
        scores = sim.gather(1, rank)
        mapping = {
            fr.name: {
                'similarity': list(),
                'ontology': extract_relations(fr),
                'URL': fr.URL,
                'definition': fr.definition,
            }
            for fr in frames
        }
        # Convert to nested Python lists once instead of creating a 0-dim
        # tensor for every element while iterating.
        rows = zip(rank.tolist(), scores.tolist())
        for left_idx, (right_indices, match_scores) in enumerate(rows):
            left_label = idx2label[left_idx]
            if left_label not in mapping:
                continue
            neighbours = mapping[left_label]['similarity']
            for right_idx, score in zip(right_indices, match_scores):
                if right_idx == left_idx:
                    # A frame is never listed as its own neighbour.
                    continue
                right_label = idx2label[right_idx]
                if right_label not in mapping:
                    continue
                neighbours.append((right_label, score))
        return mapping

    emb_map = rank_frame(emb_sim)
    mlp_map = rank_frame(mlp_sim)

    def dump(mapping, folder_path):
        """Write ``mapping`` to ``folder_path`` as one JSON and three TSV files.

        Files written: ``raw.json`` (the full mapping), ``similarity.tsv``
        (top-k neighbours per frame, only for frames that have any),
        ``ontology.tsv`` (related frames per frame) and ``kairos_sheet.tsv``
        (gold KAIROS frames followed by their top-k neighbours).
        """
        os.makedirs(folder_path, exist_ok=True)
        with open(os.path.join(folder_path, 'raw.json'), 'w') as fp:
            json.dump(mapping, fp)

        # NOTE(review): fields are joined with tabs/newlines without any
        # escaping, so a tab or newline inside a definition or description
        # would break the row layout - confirm whether that can occur.
        sim_lines, onto_lines = list(), list()
        for fr, values in mapping.items():
            header = [fr, values['definition'], values['URL']]

            onto_line = list(header)
            for rel_fr_name, rel_type in values['ontology']:
                onto_line.append(f'{rel_fr_name} ({rel_type})')
            onto_lines.append('\t'.join(onto_line))

            if values['similarity']:
                sim_line = list(header)
                for sim_fr_name, score in values['similarity'][:args.topk]:
                    sim_line.append(f'{sim_fr_name} ({score:.3f})')
                sim_lines.append('\t'.join(sim_line))

        with open(os.path.join(folder_path, 'similarity.tsv'), 'w') as fp:
            fp.write('\n'.join(sim_lines))
        with open(os.path.join(folder_path, 'ontology.tsv'), 'w') as fp:
            fp.write('\n'.join(onto_lines))

        kairos_dump = list()
        for kairos_event, kairos_content in kairos_gold_mapping.items():
            description = str(kairos_content['description'])
            for gold_entry in kairos_content['framenet']:
                gold_fr = gold_entry['label']
                if gold_fr not in fr2definition:
                    # Gold label is not a known frame; nothing to look up.
                    continue
                gold_url, gold_definition = fr2definition[gold_fr]
                kairos_dump.append([
                    'GOLD', gold_fr, kairos_event, gold_url, gold_definition,
                    description, '1.00',
                ])
                for ass_fr, sim_score in mapping[gold_fr]['similarity'][:args.topk]:
                    ass_url, ass_definition = fr2definition[ass_fr]
                    kairos_dump.append([
                        '', ass_fr, kairos_event, ass_url, ass_definition,
                        description, f'{sim_score:.2f}',
                    ])
        kairos_lines = ['\t'.join(line) for line in kairos_dump]
        with open(os.path.join(folder_path, 'kairos_sheet.tsv'), 'w') as fp:
            fp.write('\n'.join(kairos_lines))

    dump(mlp_map, os.path.join(args.dst, 'mlp'))
    dump(emb_map, os.path.join(args.dst, 'emb'))


if __name__ == '__main__':
    run()