Bart-gen-arg / TEST /ttest2.py
adherent's picture
new
44a9d55
import json
import re
import spacy
from tqdm import tqdm
from src.genie.utils import WhitespaceTokenizer
#x = 0
def find_head(arg_start, arg_end, doc):
# 设置一个临时变量 存储 论元短语的开始索引 cur_i = arg_start
cur_i = arg_start
# 进行遍历
while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <= arg_end:
if doc[cur_i].head.i == cur_i:
# self is the head
break
else:
cur_i = doc[cur_i].head.i
arg_head = cur_i
return (arg_head, arg_head)
def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None):
# 要定义一个match 作为匹配项
match = None
# arg 是论元短语 是预测文件中predicted中生成的论元短语 arg_len目前的含义是获取生成论元短语的长度
arg_len = len(arg)
# context_words 是文本 min_dis是最短距离
min_dis = len(context_words) # minimum distance to trigger
#print(arg)
#x = 0
# i 代表文本中的单词索引 w 代表文本中的i索引对应的单词
for i, w in enumerate(context_words):
# 如果文本单词列表中有一段单词 和 模型生成的单词是相等的
if context_words[i:i + arg_len] == arg:
# 如果 这个论元单词的开始索引在触发词单词索引之前
# global x += 1
# print('aa')
if i < trigger_start:
# 那么距离就是 触发词单词的开始索引减去论元短语的开始索引再减去论元短语的长度
dis = abs(trigger_start - i - arg_len)
else:
# 反之
dis = abs(i - trigger_end)
if dis < min_dis:
# match是一个元组
match = (i, i + arg_len - 1)
min_dis = dis
#print(match)
if match and head_only:
assert (doc != None)
match = find_head(match[0], match[1], doc)
#print(x)
return match
def get_event_type(ex):
evt_type = []
for evt in ex['evt_triggers']:
for t in evt[2]:
evt_type.append(t[0])
return evt_type
def extract_args_from_template(ex, template, ontology_dict,):
# extract argument text
# 这个函数的返回值是一个字典 因此需要 template列表和ex中的predicted列表同时进行遍历放入字典中
# 在这里定义两个列表 分别存放 定义存放模板的列表 TEMPLATE 和 相对应的生成 PREDICTED
# 传过来的参数中的template就是包含所有模板的列表 因此不需要再定义TEMPLATE 还是需要定义一个存放分词后的template
# 这里的template是相应事件类型下的模板包含多个
# 原来处理的方式是一个数据和一个综合性模板 现在模板是分开的 为什么要把template传过来 这不是脱裤子放屁的操作?
# 下面这段操作是因为上次模板的定义是相同因此只需要去列表中的第一个模板就行 这次需要用循环进行遍历
# print(ex)
t = []
TEMPLATE = []
for i in template:
t = i.strip().split()
TEMPLATE.append(t)
t = []
# 到此为止 得到存放该ex即该数据类型下的所有模板的分词后的列表存储 下面获取对应的predicted同理
PREDICTED = []
p = []
# 形参中插入的ex应该包含了该条数据(即该事件类型下)所有应该生成的论元对应的模板
# 在程序中出现了不一样的情况 貌似只有一条模板数据 这个问题解决了
# print(ex['predicted'])
for i in ex['predicted']:
p = i.strip().split()
PREDICTED.append(p)
p = []
# print(TEMPLATE)
# print(PREDICTED)
# 这个字典变量定义了这个函数的返回值 应该是论元角色-论元短语的key-value映射
predicted_args = {}
evt_type = get_event_type(ex)[0]
# print(evt_type)
# 不出意外的话 TEMPLATE和PREDICTED的长度应该是相等的
length = len(TEMPLATE)
for i in range(length):
#if i < 4:
#continue
template_words = TEMPLATE[i]
predicted_words = PREDICTED[i]
t_ptr = 0
p_ptr = 0
print(template_words)
print(predicted_words)
while t_ptr < len(template_words) and p_ptr < len(predicted_words):
if re.match(r'<(arg\d+)>', template_words[t_ptr]):
# print('aa')
m = re.match(r'<(arg\d+)>', template_words[t_ptr])
# 这一步的操作是从模板中得到 <arg1> 这样的词符 即arg_num 然后通过arg_num找到对应论元角色arg_name
arg_num = m.group(1)
# print(arg_num)
arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]
if predicted_words[p_ptr] == '<arg>':
# missing argument
p_ptr +=1
t_ptr +=1
else:
arg_start = p_ptr
if t_ptr + 1 == len(template_words):
while (p_ptr < len(predicted_words)):
p_ptr += 1
else:
while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr+1]):
p_ptr += 1
arg_text = predicted_words[arg_start:p_ptr]
predicted_args[arg_name] = arg_text
t_ptr += 1
# aligned
else:
t_ptr += 1
p_ptr += 1
# print(predicted_args)
return predicted_args
def pro():
nlp = spacy.load('en_core_web_sm')
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
ontology_dict = {}
with open('./aida_ontology_fj-5.csv', 'r') as f:
for lidx, line in enumerate(f):
if lidx == 0: # header
continue
fields = line.strip().split(',')
if len(fields) < 2:
break
evt_type = fields[0]
if evt_type in ontology_dict.keys():
arguments = fields[2:]
ontology_dict[evt_type]['template'].append(fields[1])
for i, arg in enumerate(arguments):
if arg != '':
ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
else:
ontology_dict[evt_type] = {}
arguments = fields[2:]
ontology_dict[evt_type]['template'] = []
ontology_dict[evt_type]['template'].append(fields[1])
for i, arg in enumerate(arguments):
if arg != '':
ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
examples = {}
x = 0
with open('./data/RAMS_1.0/data/test_head_coref.jsonlines', 'r') as f:
for line in f:
x += 1
ex = json.loads(line.strip())
ex['ref_evt_links'] = ex['gold_evt_links']
ex['gold_evt_links'] = []
examples[ex['doc_key']] = ex
flag = {}
y = 0
with open('./checkpoints/gen-RAMS-pred/predictions.jsonl', 'r') as f:
for line in f:
y += 1
pred = json.loads(line.strip())
# print(pred['predicted'])
if pred['doc_key'] in flag.keys():
examples[pred['doc_key']]['predicted'].append(pred['predicted'])
examples[pred['doc_key']]['gold'].append(pred['gold'])
# 如果没有 说明这是新的事件类型
else:
flag[pred['doc_key']] = True
examples[pred['doc_key']]['predicted'] = []
examples[pred['doc_key']]['gold'] = []
# 然后将此条数据存入
examples[pred['doc_key']]['predicted'].append(pred['predicted'])
examples[pred['doc_key']]['gold'].append(pred['gold'])
# print(len(examples), x, y) 871 871 3614
for ex in tqdm(examples.values()):
if 'predicted' not in ex:# this is used for testing
continue
# print(ex)
# break
# print(ex)
# get template 获取事件类型
# print('nw_RC00c8620ef5810429342a1c339e6c76c1b0b9add3f6010f04482fd832')
evt_type = get_event_type(ex)[0]
context_words = [w for sent in ex['sentences'] for w in sent]
# 这里的template是ontology_dict中 template 包含一个事件类型下的所有事件模板
template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
# extract argument text
# 这里应该是提取预测文件中预测到的论元短语 ex是一条json数据 template是这条json数据对应下的模板 on是论元角色和<arg1>的映射
# 这里ex中的predicted和gold已经包括了该事件类型下的所有论元 用列表的形式进行存储 且顺序是一一对应的
# 这里返回的predicted_args是一个字典:
# ex = {'predicted': [' A man attacked target using something at place in order to take something', ' Attacker attacked EgyptAir plane using something at place in order to take something', ' Attacker attacked target using a suicide belt at place in order to take something', ' Attacker attacked target using something at Flight 181 place in order to take something', ' Attacker attacked target using something at place in order to take EgyptAir Flight 181']}
# template = ontology_dict['conflict.attack.stealrobhijack']['template']
# print(ex)
predicted_args = extract_args_from_template(ex, template, ontology_dict)
# print(predicted_args)
# break
trigger_start = ex['evt_triggers'][0][0]
trigger_end = ex['evt_triggers'][0][1]
# 上面返回的predicted_args是一个字典 暂时认为是论元角色和具体论元短语的映射
# 还没有发现doc的作用
doc = None
# 通过test_rams.sh文件的设置 可以发现args.head_only的值为true
head_only = True
if head_only:
# # 从原始文本中取出标记
doc = nlp(' '.join(context_words))
for argname in predicted_args:
# 通过find_arg_span函数找出
arg_span = find_arg_span(predicted_args[argname], context_words,
trigger_start, trigger_end, head_only=True, doc=doc)
# print()
#print(arg_span)
pro()
#print(x)
# dict = {'A': 1, 'B': 2, 'C': 3}
#
# for x in dict:
# print(x)
# if '1' in dict.keys():
# print('aaaaaaaa')