Spaces:
Build error
Build error
import json | |
import re | |
import spacy | |
from tqdm import tqdm | |
from src.genie.utils import WhitespaceTokenizer | |
#x = 0 | |
def find_head(arg_start, arg_end, doc): | |
# 设置一个临时变量 存储 论元短语的开始索引 cur_i = arg_start | |
cur_i = arg_start | |
# 进行遍历 | |
while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <= arg_end: | |
if doc[cur_i].head.i == cur_i: | |
# self is the head | |
break | |
else: | |
cur_i = doc[cur_i].head.i | |
arg_head = cur_i | |
return (arg_head, arg_head) | |
def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None): | |
# 要定义一个match 作为匹配项 | |
match = None | |
# arg 是论元短语 是预测文件中predicted中生成的论元短语 arg_len目前的含义是获取生成论元短语的长度 | |
arg_len = len(arg) | |
# context_words 是文本 min_dis是最短距离 | |
min_dis = len(context_words) # minimum distance to trigger | |
#print(arg) | |
#x = 0 | |
# i 代表文本中的单词索引 w 代表文本中的i索引对应的单词 | |
for i, w in enumerate(context_words): | |
# 如果文本单词列表中有一段单词 和 模型生成的单词是相等的 | |
if context_words[i:i + arg_len] == arg: | |
# 如果 这个论元单词的开始索引在触发词单词索引之前 | |
# global x += 1 | |
# print('aa') | |
if i < trigger_start: | |
# 那么距离就是 触发词单词的开始索引减去论元短语的开始索引再减去论元短语的长度 | |
dis = abs(trigger_start - i - arg_len) | |
else: | |
# 反之 | |
dis = abs(i - trigger_end) | |
if dis < min_dis: | |
# match是一个元组 | |
match = (i, i + arg_len - 1) | |
min_dis = dis | |
#print(match) | |
if match and head_only: | |
assert (doc != None) | |
match = find_head(match[0], match[1], doc) | |
#print(x) | |
return match | |
def get_event_type(ex): | |
evt_type = [] | |
for evt in ex['evt_triggers']: | |
for t in evt[2]: | |
evt_type.append(t[0]) | |
return evt_type | |
def extract_args_from_template(ex, template, ontology_dict,): | |
# extract argument text | |
# 这个函数的返回值是一个字典 因此需要 template列表和ex中的predicted列表同时进行遍历放入字典中 | |
# 在这里定义两个列表 分别存放 定义存放模板的列表 TEMPLATE 和 相对应的生成 PREDICTED | |
# 传过来的参数中的template就是包含所有模板的列表 因此不需要再定义TEMPLATE 还是需要定义一个存放分词后的template | |
# 这里的template是相应事件类型下的模板包含多个 | |
# 原来处理的方式是一个数据和一个综合性模板 现在模板是分开的 为什么要把template传过来 这不是脱裤子放屁的操作? | |
# 下面这段操作是因为上次模板的定义是相同因此只需要去列表中的第一个模板就行 这次需要用循环进行遍历 | |
# print(ex) | |
t = [] | |
TEMPLATE = [] | |
for i in template: | |
t = i.strip().split() | |
TEMPLATE.append(t) | |
t = [] | |
# 到此为止 得到存放该ex即该数据类型下的所有模板的分词后的列表存储 下面获取对应的predicted同理 | |
PREDICTED = [] | |
p = [] | |
# 形参中插入的ex应该包含了该条数据(即该事件类型下)所有应该生成的论元对应的模板 | |
# 在程序中出现了不一样的情况 貌似只有一条模板数据 这个问题解决了 | |
# print(ex['predicted']) | |
for i in ex['predicted']: | |
p = i.strip().split() | |
PREDICTED.append(p) | |
p = [] | |
# print(TEMPLATE) | |
# print(PREDICTED) | |
# 这个字典变量定义了这个函数的返回值 应该是论元角色-论元短语的key-value映射 | |
predicted_args = {} | |
evt_type = get_event_type(ex)[0] | |
# print(evt_type) | |
# 不出意外的话 TEMPLATE和PREDICTED的长度应该是相等的 | |
length = len(TEMPLATE) | |
for i in range(length): | |
#if i < 4: | |
#continue | |
template_words = TEMPLATE[i] | |
predicted_words = PREDICTED[i] | |
t_ptr = 0 | |
p_ptr = 0 | |
print(template_words) | |
print(predicted_words) | |
while t_ptr < len(template_words) and p_ptr < len(predicted_words): | |
if re.match(r'<(arg\d+)>', template_words[t_ptr]): | |
# print('aa') | |
m = re.match(r'<(arg\d+)>', template_words[t_ptr]) | |
# 这一步的操作是从模板中得到 <arg1> 这样的词符 即arg_num 然后通过arg_num找到对应论元角色arg_name | |
arg_num = m.group(1) | |
# print(arg_num) | |
arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num] | |
if predicted_words[p_ptr] == '<arg>': | |
# missing argument | |
p_ptr +=1 | |
t_ptr +=1 | |
else: | |
arg_start = p_ptr | |
if t_ptr + 1 == len(template_words): | |
while (p_ptr < len(predicted_words)): | |
p_ptr += 1 | |
else: | |
while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr+1]): | |
p_ptr += 1 | |
arg_text = predicted_words[arg_start:p_ptr] | |
predicted_args[arg_name] = arg_text | |
t_ptr += 1 | |
# aligned | |
else: | |
t_ptr += 1 | |
p_ptr += 1 | |
# print(predicted_args) | |
return predicted_args | |
def pro(): | |
nlp = spacy.load('en_core_web_sm') | |
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab) | |
ontology_dict = {} | |
with open('./aida_ontology_fj-5.csv', 'r') as f: | |
for lidx, line in enumerate(f): | |
if lidx == 0: # header | |
continue | |
fields = line.strip().split(',') | |
if len(fields) < 2: | |
break | |
evt_type = fields[0] | |
if evt_type in ontology_dict.keys(): | |
arguments = fields[2:] | |
ontology_dict[evt_type]['template'].append(fields[1]) | |
for i, arg in enumerate(arguments): | |
if arg != '': | |
ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg | |
ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1) | |
else: | |
ontology_dict[evt_type] = {} | |
arguments = fields[2:] | |
ontology_dict[evt_type]['template'] = [] | |
ontology_dict[evt_type]['template'].append(fields[1]) | |
for i, arg in enumerate(arguments): | |
if arg != '': | |
ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg | |
ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1) | |
examples = {} | |
x = 0 | |
with open('./data/RAMS_1.0/data/test_head_coref.jsonlines', 'r') as f: | |
for line in f: | |
x += 1 | |
ex = json.loads(line.strip()) | |
ex['ref_evt_links'] = ex['gold_evt_links'] | |
ex['gold_evt_links'] = [] | |
examples[ex['doc_key']] = ex | |
flag = {} | |
y = 0 | |
with open('./checkpoints/gen-RAMS-pred/predictions.jsonl', 'r') as f: | |
for line in f: | |
y += 1 | |
pred = json.loads(line.strip()) | |
# print(pred['predicted']) | |
if pred['doc_key'] in flag.keys(): | |
examples[pred['doc_key']]['predicted'].append(pred['predicted']) | |
examples[pred['doc_key']]['gold'].append(pred['gold']) | |
# 如果没有 说明这是新的事件类型 | |
else: | |
flag[pred['doc_key']] = True | |
examples[pred['doc_key']]['predicted'] = [] | |
examples[pred['doc_key']]['gold'] = [] | |
# 然后将此条数据存入 | |
examples[pred['doc_key']]['predicted'].append(pred['predicted']) | |
examples[pred['doc_key']]['gold'].append(pred['gold']) | |
# print(len(examples), x, y) 871 871 3614 | |
for ex in tqdm(examples.values()): | |
if 'predicted' not in ex:# this is used for testing | |
continue | |
# print(ex) | |
# break | |
# print(ex) | |
# get template 获取事件类型 | |
# print('nw_RC00c8620ef5810429342a1c339e6c76c1b0b9add3f6010f04482fd832') | |
evt_type = get_event_type(ex)[0] | |
context_words = [w for sent in ex['sentences'] for w in sent] | |
# 这里的template是ontology_dict中 template 包含一个事件类型下的所有事件模板 | |
template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template'] | |
# extract argument text | |
# 这里应该是提取预测文件中预测到的论元短语 ex是一条json数据 template是这条json数据对应下的模板 on是论元角色和<arg1>的映射 | |
# 这里ex中的predicted和gold已经包括了该事件类型下的所有论元 用列表的形式进行存储 且顺序是一一对应的 | |
# 这里返回的predicted_args是一个字典: | |
# ex = {'predicted': [' A man attacked target using something at place in order to take something', ' Attacker attacked EgyptAir plane using something at place in order to take something', ' Attacker attacked target using a suicide belt at place in order to take something', ' Attacker attacked target using something at Flight 181 place in order to take something', ' Attacker attacked target using something at place in order to take EgyptAir Flight 181']} | |
# template = ontology_dict['conflict.attack.stealrobhijack']['template'] | |
# print(ex) | |
predicted_args = extract_args_from_template(ex, template, ontology_dict) | |
# print(predicted_args) | |
# break | |
trigger_start = ex['evt_triggers'][0][0] | |
trigger_end = ex['evt_triggers'][0][1] | |
# 上面返回的predicted_args是一个字典 暂时认为是论元角色和具体论元短语的映射 | |
# 还没有发现doc的作用 | |
doc = None | |
# 通过test_rams.sh文件的设置 可以发现args.head_only的值为true | |
head_only = True | |
if head_only: | |
# # 从原始文本中取出标记 | |
doc = nlp(' '.join(context_words)) | |
for argname in predicted_args: | |
# 通过find_arg_span函数找出 | |
arg_span = find_arg_span(predicted_args[argname], context_words, | |
trigger_start, trigger_end, head_only=True, doc=doc) | |
# print() | |
#print(arg_span) | |
pro() | |
#print(x) | |
# dict = {'A': 1, 'B': 2, 'C': 3} | |
# | |
# for x in dict: | |
# print(x) | |
# if '1' in dict.keys(): | |
# print('aaaaaaaa') | |