# Bart-gen-arg / src / runner.py
import difflib
import os
import json
from tqdm import tqdm
from glob import glob
#
# One-off inspection script: check whether any trigger phrase in the RAMS test
# split is annotated with more than one event type.
# if not os.path.exists('./evttgr2type.json'):
#     for file_name in glob('data/RAMS_1.0/data/test.jsonlines'):
#         dic = {}
#         with open(file_name, 'r', encoding='utf-8') as f:
#             lines = f.readlines()
#         for line in tqdm(lines):
#             linej = json.loads(line.strip())
#             evt_triggers = linej['evt_triggers']
#             sentences = linej['sentences']
#             sentences_uni = []
#             for s in sentences:
#                 sentences_uni += s
#             print(' '.join(sentences_uni))
#             triggers = ' '.join(sentences_uni[evt_triggers[0][0]:evt_triggers[0][1] + 1])
#             evt_type = evt_triggers[0][2][0][0]
#             if triggers in dic and dic[triggers] != evt_type:
#                 print('One trigger maps to different event types: {} {} {}'.format(triggers, evt_type, dic[triggers]))
#             dic[triggers] = evt_type
#             print(evt_type, triggers)
#     exit()
import argparse
import jsonlines
import torch
from src.genie.data import my_collate
from src.genie.data_module_w import RAMSDataModule
from src.genie.model import GenIEModel
import gradio as gr
import re
from transformers import BartTokenizer
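# Sequence-length limits (in subword tokens). Only MAX_LENGTH is used in this
# file; the other two presumably mirror the training configuration.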
MAX_LENGTH = 424
MAX_TGT_LENGTH = 72
DOC_STRIDE = 256
class DataModule4():
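    """Data module for templates whose numbered <argN> placeholders are
    normalized to <arg>; builds one BART encoder input per template of the
    event type, pairing the template with a trigger-marked copy of the text."""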
def __init__(self, ontology_file):
super().__init__()
self.ontology_file = ontology_file
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
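        # The added special tokens must match those used at training time; the
        # loaded checkpoint is assumed to contain embeddings for them.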
self.ontology_dict = self.load_ontology()
def create_gold_gen(self, context_words, evt_type, trigger):
        # Aggregate lists holding the tokenized input templates and the contexts.
INPUT = []
CONTEXT = []
input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
i = len(input_template)
input_list = []
        for x in range(i):
            tpl = re.sub(r'<arg\d>', '<arg>', input_template[x])
            input_list.append(tpl)
        # input_list now holds the templates with every numbered <argN> placeholder replaced by <arg>; next, split on spaces.
temp = []
for x in range(i):
space_tokenized_template = input_list[x].split(' ')
temp.append(space_tokenized_template)
        # temp holds the space-split templates; next, run the BART tokenizer over every word.
tokenized_input_template = []
for x in range(len(temp)):
for w in temp[x]:
tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
INPUT.append(tokenized_input_template)
tokenized_input_template = []
context_words = context_words.split(' ')
trigger_words = trigger.split(' ')
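        # Note: list.index() finds the first occurrence of each trigger word,
        # so an identical earlier token in the document would be marked instead
        # of the intended mention.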
trigger_span_start = context_words.index(trigger_words[0])
trigger_span_end = context_words.index(trigger_words[-1])
        # words before the trigger
prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger_span_start]), add_prefix_space=True)
        # the trigger phrase
tgt = self.tokenizer.tokenize(trigger, add_prefix_space=True)
        # words after the trigger
suffix = self.tokenizer.tokenize(' '.join(context_words[trigger_span_end+1:]), add_prefix_space=True)
context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
# context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)
        # duplicate the context once per template
for w in range(i):
CONTEXT.append(context)
return INPUT, CONTEXT
    def load_ontology(self):
        # Parse the ontology CSV: each row is evt_type, template, arg1, ..., argN.
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in enumerate(f):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type not in ontology_dict:
                    ontology_dict[evt_type] = {'template': []}
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(fields[2:]):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict
def prepare_data(self, sentences, evt_type, trigger):
input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
length = len(input_template)
# print(input_template)
# print(context)
results = []
for i in range(length):
input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
add_special_tokens=True,
add_prefix_space=True,
max_length=MAX_LENGTH,
truncation='only_second',
padding='max_length')
            # input_ids: the token ids of the encoded template/context pair
results.append(input_tokens['input_ids'])
temp = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]
return results, temp
class DataModuleW():
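    """Variant data module whose templates embed the trigger directly: every
    <trg> placeholder is replaced with the trigger phrase before encoding."""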
def __init__(self, ontology_file):
super().__init__()
self.ontology_file = ontology_file
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
self.ontology_dict = self.load_ontology()
def create_gold_gen(self, context_words, evt_type, trigger):
        # Aggregate lists holding the tokenized input templates and the contexts.
INPUT = []
CONTEXT = []
input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
i = len(input_template)
input_list = []
        for x in range(i):
            tpl = re.sub('<trg>', trigger, input_template[x])
            input_list.append(tpl)
        # input_list now holds the templates with <trg> replaced by the trigger phrase; next, split on spaces.
temp = []
for x in range(i):
space_tokenized_template = input_list[x].split(' ')
temp.append(space_tokenized_template)
        # temp holds the space-split templates; next, run the BART tokenizer over every word.
tokenized_input_template = []
for x in range(len(temp)):
for w in temp[x]:
tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
INPUT.append(tokenized_input_template)
tokenized_input_template = []
context = self.tokenizer.tokenize(context_words, add_prefix_space=True)
        # duplicate the context once per template
for w in range(i):
CONTEXT.append(context)
return INPUT, CONTEXT
    def load_ontology(self):
        # Parse the ontology CSV: each row is evt_type, template, arg1, ..., argN.
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in tqdm(enumerate(f)):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type not in ontology_dict:
                    ontology_dict[evt_type] = {'template': []}
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(fields[2:]):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict
def prepare_data(self, sentences, evt_type, trigger):
input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
length = len(input_template)
# print(input_template)
# print(output_template)
# print(context)
results = []
for i in range(length):
input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
add_special_tokens=True,
add_prefix_space=True,
max_length=MAX_LENGTH,
truncation='only_second',
padding='max_length')
            # input_ids: the token ids of the encoded template/context pair
results.append(input_tokens['input_ids'])
temp = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]
return results, temp
class Runner():
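    """Wraps a trained GenIEModel checkpoint: rebuilds the training-time
    argparse namespace with fixed defaults, loads the checkpoint weights, and
    exposes pred() for batched generation."""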
    def __init__(self, load_ckpt='checkpoints/gen-RAMS-what-new-span/epoch=2-v0.ckpt'):
model = 'gen'
self.ckpt_name = 'gen-RAMS-pred'
self.load_ckpt = load_ckpt
self.dataset = 'RAMS'
self.eval_only = True
self.train_file = 'data/RAMS_1.0/data/train.jsonlines'
self.val_file = 'data/RAMS_1.0/data/dev.jsonlines'
self.test_file = 'data/RAMS_1.0/data/test.jsonlines'
self.train_batch_size = 2
self.eval_batch_size = 4
self.learning_rate = 3e-5
self.accumulate_grad_batches = 4
self.num_train_epochs = 3
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model",
type=str,
default=model
)
parser.add_argument(
"--dataset",
type=str,
default=self.dataset
)
parser.add_argument('--tmp_dir', type=str)
parser.add_argument(
"--ckpt_name",
default=self.ckpt_name,
type=str,
help="The output directory where the model checkpoints and predictions will be written.",
)
parser.add_argument(
"--load_ckpt",
default=self.load_ckpt,
type=str,
)
parser.add_argument(
"--train_file",
default=self.train_file,
type=str,
help="The input training file. If a data dir is specified, will look for the file there"
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
"--val_file",
default=self.val_file,
type=str,
help="The input evaluation file. If a data dir is specified, will look for the file there"
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
)
parser.add_argument(
'--test_file',
type=str,
default=self.test_file,
)
parser.add_argument('--input_dir', type=str, default=None)
parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs')
parser.add_argument('--use_info', action='store_true', default=False,
help='use informative mentions instead of the nearest mention.')
parser.add_argument('--mark_trigger', action='store_true')
        parser.add_argument('--sample-gen', action='store_true', help='Do sampling during generation.')
parser.add_argument("--train_batch_size", default=self.train_batch_size, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument(
"--eval_batch_size", default=self.eval_batch_size, type=int, help="Batch size per GPU/CPU for evaluation."
)
parser.add_argument("--learning_rate", default=self.learning_rate, type=float,
help="The initial learning rate for Adam.")
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=self.accumulate_grad_batches,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs", default=self.num_train_epochs, type=int,
help="Total number of training epochs to perform."
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--gpus", default=None, help='-1 means train on all gpus')
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument("--threads", type=int, default=1,
help="multiple threads for converting example to features")
self.args = parser.parse_args()
self.model = GenIEModel(self.args)
self.model.load_state_dict(torch.load(self.args.load_ckpt, map_location=self.model.device)['state_dict'])
    def pred(self, inputs):
        # Stack the encoded examples into a single LongTensor batch and
        # delegate to the model's pred().
        x = torch.stack([torch.LongTensor(u) for u in inputs])
        return self.model.pred(x)
print('Loading data...')
dm1 = DataModule4('aida_ontology_cleaned.csv')
dm2 = DataModuleW('aida_ontology_fj-w-2.csv')
dm3 = DataModuleW('aida_ontology_fj-w-3.csv')
dm4 = DataModule4('aida_ontology_fj-5.csv')
print('Loading Model 1...')
runner1 = Runner('checkpoints/gen-RAMS-1-span/epoch=2-v1.ckpt')
print('Loading Model 2...')
runner2 = Runner('checkpoints/gen-RAMS-2-span/epoch=2-v0.ckpt')
print('Loading Model 3...')
runner3 = Runner('checkpoints/gen-RAMS-3-span/epoch=2-v0.ckpt')
print('Loading Model 4...')
runner4 = Runner('checkpoints/gen-RAMS-4-span/epoch=2-v0.ckpt')
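# Minimal usage sketch (hypothetical document, trigger, and event type; assumes
# the checkpoints above exist):
#   ids, _ = dm1.prepare_data('The embassy was bombed on Friday .',
#                             'conflict.attack.unspecified', 'bombed')
#   print(runner1.pred(ids))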
data_modules = [dm1, dm2, dm3, dm4]
runners = [runner1, runner2, runner3, runner4]

def handle(sentences, trigger, temp=3, evt_type='contact.prevarication.broadcast'):
    # temp is the 0-based template index from the dropdown below.
    x, argnames = data_modules[temp].prepare_data(sentences, evt_type, trigger)
    ys = runners[temp].pred(x)
print(ys)
results = []
for y in ys:
        while '  ' in y:
            y = y.replace('  ', ' ')
result = y.strip(' ').split(' ')
results.append(result)
print(results)
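    # Align each generated string with its template word-by-word: wherever the
    # template has an <argN> placeholder, read off the generated token at the
    # same position (assumes the model reproduces the template structure).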
argss = []
for n,template in enumerate(argnames['template']):
template = template.split(' ')
# print(template)
args = []
        for i, w in enumerate(template):
            if '<arg' in w:
                # map '<arg1>' to its ontology key 'arg1', then to the full
                # role name (shaped like evtNNNargNNrole) and keep the role suffix
                placeholder = re.match(r'<(\w+)>', w).group(1)
                m = re.match(r'evt\d+arg\d+(\w+)', argnames[placeholder])
                if m:
                    label = m.group(1)
                    if results[n][i] == '<arg>':
                        args.append(label + ': None')
                    else:
                        args.append(label + ': ' + results[n][i])
argss.append(', '.join(args))
return '\n'.join(argss)
if __name__ == "__main__":
# trigger = 'deceive'
# sentences = """We are ashamed of them . " However , Mutko stopped short of admitting the doping scandal was state sponsored . " We are very sorry that athletes who tried to deceive us , and the world , were not caught sooner . We are very sorry because Russia is committed to upholding the highest standards in sport and is opposed to anything that threatens the Olympic values , " he said . English former heptathlete and Athens 2004 bronze medallist Kelly Sotherton was unhappy with Mutko 's plea for Russia 's ban to be lifted for Rio"""
# print(handle(sentences, trigger))
dm_key = list(dm1.ontology_dict.keys())
print(len(dm_key))
    def get_tmp(index, evt_type):
        if index is None or evt_type is None:
            return ''
        input_template = data_modules[index].ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        return '\n'.join(input_template)
with gr.Blocks() as demo:
with gr.Row().style(equal_height=False):
with gr.Column(variant="panel"):
                stens = gr.Text(label='Document')
                evt_type = gr.Dropdown(choices=dm_key, label='Event type')
                trigger = gr.Text(label='Trigger')
                temp = gr.Dropdown(choices=['Base template', 'Simple sub-templates',
                                            'Sub-templates with semantic information',
                                            'Sub-templates with argument information'],
                                   type='index', value='Base template', label='Template')
                output_tmp = gr.Text(label='Template content')
                btn = gr.Button("Run")
            with gr.Column(variant="panel"):
                result = gr.Text(label='Output')
        evt_type.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        temp.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        btn.click(fn=handle, inputs=[stens, trigger, temp, evt_type], outputs=[result])
    demo.launch(server_name='0.0.0.0', server_port=6006, share=True)