File size: 2,004 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from argparse import ArgumentParser
import hashlib
import os

from sftp.data_reader import BetterDatasetReader, ConcreteDatasetReader
from tools.ontology_mapping.force_map import ontology_map, read_framenet


def read_ace_better(reader, data_path):
    """Collect (sentence, spans) pairs from a target dataset.

    Each instance yielded by ``reader.read(data_path)`` is expected to have a
    ``'raw_inputs'`` field whose ``metadata`` mapping contains the keys
    ``'sentence'`` and ``'spans'``. Other metadata keys are ignored.

    Args:
        reader: a dataset reader exposing ``read(data_path)``.
        data_path: path passed straight through to ``reader.read``.

    Returns:
        A list with one ``(sentence, spans)`` tuple per instance, in read order.
    """
    wanted_keys = ('sentence', 'spans')
    return [
        tuple(instance.fields['raw_inputs'].metadata[k] for k in wanted_keys)
        for instance in reader.read(data_path)
    ]


def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*, hashing it chunk by chunk."""
    digest = hashlib.md5()
    # `with` guarantees the handle is closed; chunked reads avoid loading a
    # potentially large model archive into memory all at once.
    with open(path, 'rb') as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def run(model_path, src_data_path, tgt_data_path, device, dst_path):
    """Run ontology mapping from FrameNet source data onto a target dataset.

    Steps: fingerprint the model archive with MD5 (printed and recorded in
    ``meta``), pick a target-dataset reader based on ``tgt_data_path``, read
    the source data with ``read_framenet`` and the target data with
    ``read_ace_better``, then hand everything to ``ontology_map``.

    Args:
        model_path: either a path ending in ``.tar.gz`` (the archive itself)
            or a directory containing ``model.tar.gz``.
        src_data_path: source (FrameNet) data path, passed to ``read_framenet``.
        tgt_data_path: target data path. The reader is chosen by a
            case-insensitive substring test: ``'better'`` selects
            ``BetterDatasetReader``; otherwise ``'ace'`` selects
            ``ConcreteDatasetReader``.
        device: passed through unchanged to ``ontology_map`` (the CLI default
            is -1; presumably a CPU/GPU index, to be confirmed in ``ontology_map``).
        dst_path: output destination, passed through to ``ontology_map``.

    Raises:
        NotImplementedError: if ``tgt_data_path`` names neither dataset.
    """
    if model_path.endswith('.tar.gz'):
        archive_path = model_path
    else:
        archive_path = os.path.join(model_path, 'model.tar.gz')
    model_md5 = _file_md5(archive_path)
    print('model md5: ', model_md5)

    tgt_lower = tgt_data_path.lower()
    # 'better' is tested first, so a path containing both substrings uses the
    # BETTER reader. NOTE(review): 'ace' is a plain substring test and also
    # matches e.g. 'workspace'.
    if 'better' in tgt_lower:
        reader = BetterDatasetReader(eval_type='basic', pretrained_model='roberta-large', ignore_label=False)
    elif 'ace' in tgt_lower:
        reader = ConcreteDatasetReader(ignore_unlabeled_sentence=True, pretrained_model='roberta-large')
    else:
        raise NotImplementedError(
            f'Unsupported target dataset (expected "better" or "ace" in path): {tgt_data_path}'
        )

    meta = {
        'model': {'path': model_path, 'md5': model_md5},
        'src_data_path': src_data_path,
        'tgt_data_path': tgt_data_path
    }
    # Source ontology data comes from FrameNet; target data is the ACE/BETTER
    # (sentence, spans) list produced by read_ace_better.
    src_data, tgt_data = read_framenet(src_data_path), read_ace_better(reader, tgt_data_path)
    ontology_map(model_path, src_data, tgt_data, device, dst_path, meta)


if __name__ == '__main__':
    # Command-line entry point: four positional paths plus an optional -d
    # integer (default -1) that is forwarded to run() as `device`.
    arg_parser = ArgumentParser()
    positional_args = (
        ('model', 'MODEL_PATH'),
        ('src', 'SRC_DATA_PATH'),
        ('tgt', 'TGT_DATA_PATH'),
        ('dst', 'DESTINATION_PATH'),
    )
    for arg_name, arg_metavar in positional_args:
        arg_parser.add_argument(arg_name, metavar=arg_metavar)
    arg_parser.add_argument('-d', type=int, help='device', default=-1)
    args = arg_parser.parse_args()
    run(
        model_path=args.model,
        src_data_path=args.src,
        tgt_data_path=args.tgt,
        device=args.d,
        dst_path=args.dst,
    )