import os
import json
from tqdm import tqdm
from glob import glob
#
# if not os.path.exists('./evttgr2type.json'):
#     for file_name in glob('data/RAMS_1.0/data/test.jsonlines'):
#         dic = {}
#         with open(file_name, 'r', encoding='utf-8') as f:
#             lines = f.readlines()
#         for line in tqdm(lines):
#             linej = json.loads(line.strip())
#             evt_triggers = linej['evt_triggers']
#             # print(evt_triggers)
#             sentences = linej['sentences']
#             # print(sentences)
#             sentences_uni = []
#             for s in sentences:
#                 sentences_uni += s
#             print(' '.join(sentences_uni))
#             triggers = ' '.join(sentences_uni[evt_triggers[0][0]:evt_triggers[0][1]+1])
#             evt_type = evt_triggers[0][2][0][0]
#             if triggers in dic:
#                 if dic[triggers] != evt_type:
#                     print('One trigger word maps to different event types: {} {} {}'.format(triggers, evt_type, dic[triggers]))
#             dic[triggers] = evt_type
#             print(evt_type, triggers)
#     exit()
import argparse
import torch
from src.genie.model import GenIEModel
import gradio as gr
import re
from transformers import BartTokenizer
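# Token-count limits. MAX_LENGTH caps the encoder input in encode_plus below;
# MAX_TGT_LENGTH and DOC_STRIDE are unused in this script and presumably belong
# to the training-side data modules.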
MAX_LENGTH = 424
MAX_TGT_LENGTH = 72
DOC_STRIDE = 256
class DataModule4():
    def __init__(self, ontology_file):
        super().__init__()
        self.ontology_file = ontology_file
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
        self.ontology_dict = self.load_ontology()
    def create_gold_gen(self, context_words, evt_type, trigger):
        # Two accumulators: one for the input templates, one for the matching contexts.
        INPUT = []
        CONTEXT = []
        input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        i = len(input_template)
        input_list = []
        for x in range(i):
            # Replace every numbered <argN> placeholder with the generic <arg> token.
            template = re.sub(r'<arg\d>', '<arg>', input_template[x])
            input_list.append(template)
        # input_list now holds the templates with <arg1>, <arg2>, ... replaced by
        # <arg>; next, split each template on whitespace.
        temp = []
        for x in range(i):
            space_tokenized_template = input_list[x].split(' ')
            temp.append(space_tokenized_template)
        # temp holds the whitespace-split templates; next, BPE-tokenize every word.
        tokenized_input_template = []
        for x in range(len(temp)):
            for w in temp[x]:
                tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
            INPUT.append(tokenized_input_template)
            tokenized_input_template = []
        context_words = context_words.split(' ')
        trigger_words = trigger.split(' ')
        # Note: index() returns the first occurrence, so a repeated trigger word is
        # located at its first mention in the document.
        trigger_span_start = context_words.index(trigger_words[0])
        trigger_span_end = context_words.index(trigger_words[-1])
        # Words before the trigger
        prefix = self.tokenizer.tokenize(' '.join(context_words[:trigger_span_start]), add_prefix_space=True)
        # The trigger phrase itself
        tgt = self.tokenizer.tokenize(trigger, add_prefix_space=True)
        # Words after the trigger
        suffix = self.tokenizer.tokenize(' '.join(context_words[trigger_span_end+1:]), add_prefix_space=True)
        context = prefix + [' <tgr>', ] + tgt + [' <tgr>', ] + suffix
        # context = self.tokenizer.tokenize(' '.join(context_words), add_prefix_space=True)
        # Pair each template with a copy of the marked context.
        for w in range(i):
            CONTEXT.append(context)
        return INPUT, CONTEXT
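    # Illustrative walk-through (the template and sentence are made-up examples):
    #   template '<arg1> lied to <arg2>' -> '<arg> lied to <arg>' -> BPE tokens
    #   context 'the official lied to reporters', trigger 'lied'
    #     -> prefix tokens + [' <tgr>'] + trigger tokens + [' <tgr>'] + suffix tokens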
    def load_ontology(self):
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in enumerate(f):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type in ontology_dict.keys():
                    args = fields[2:]
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
                else:
                    ontology_dict[evt_type] = {}
                    args = fields[2:]
                    ontology_dict[evt_type]['template'] = []
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict
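    # Assumed ontology CSV layout, inferred from the parsing above (one row per
    # template, first row a header; the role names below are illustrative only):
    #   evt_type,template,arg1,arg2,...
    #   contact.prevarication.broadcast,<arg1> lied about <arg2>,evt001arg01communicator,evt001arg02topic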
    def prepare_data(self, sentences, evt_type, trigger):
        input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
        length = len(input_template)
        # print(input_template)
        # print(context)
        results = []
        for i in range(length):
            input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
                                                      add_special_tokens=True,
                                                      add_prefix_space=True,
                                                      max_length=MAX_LENGTH,
                                                      truncation='only_second',
                                                      padding='max_length')
            # input_ids: vocabulary ids of the encoded tokens
            results.append(input_tokens['input_ids'])
        temp = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]
        return results, temp
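# Illustrative usage (the event type and trigger are examples, not checked here):
#   dm = DataModule4('aida_ontology_cleaned.csv')
#   input_ids, onto_entry = dm.prepare_data(document_text, 'contact.prevarication.broadcast', 'lied')
#   # input_ids: one MAX_LENGTH-padded id list per template;
#   # onto_entry: the ontology record mapping arg slots to role names.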
class DataModuleW():
    def __init__(self, ontology_file):
        super().__init__()
        self.ontology_file = ontology_file
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])
        self.ontology_dict = self.load_ontology()
    def create_gold_gen(self, context_words, evt_type, trigger):
        # Two accumulators: one for the input templates, one for the matching contexts.
        INPUT = []
        CONTEXT = []
        input_template = self.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        i = len(input_template)
        input_list = []
        for x in range(i):
            # re.sub already replaces every <trg> occurrence with the literal trigger.
            template = re.sub('<trg>', trigger, input_template[x])
            input_list.append(template)
        # input_list now holds the templates with <trg> filled in; next, split each
        # template on whitespace.
        temp = []
        for x in range(i):
            space_tokenized_template = input_list[x].split(' ')
            temp.append(space_tokenized_template)
        # temp holds the whitespace-split templates; next, BPE-tokenize every word.
        tokenized_input_template = []
        for x in range(len(temp)):
            for w in temp[x]:
                tokenized_input_template.extend(self.tokenizer.tokenize(w, add_prefix_space=True))
            INPUT.append(tokenized_input_template)
            tokenized_input_template = []
        context = self.tokenizer.tokenize(context_words, add_prefix_space=True)
        # Pair each template with a copy of the context.
        for w in range(i):
            CONTEXT.append(context)
        return INPUT, CONTEXT
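    # Unlike DataModule4.create_gold_gen above, this variant fills the literal
    # trigger word into the <trg> slots of each template and does not wrap the
    # trigger in the context with <tgr> markers.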
    def load_ontology(self):
        ontology_dict = {}
        with open(self.ontology_file, 'r') as f:
            for lidx, line in tqdm(enumerate(f)):
                if lidx == 0:  # header
                    continue
                fields = line.strip().split(',')
                if len(fields) < 2:
                    break
                evt_type = fields[0]
                if evt_type in ontology_dict.keys():
                    args = fields[2:]
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
                else:
                    ontology_dict[evt_type] = {}
                    args = fields[2:]
                    ontology_dict[evt_type]['template'] = []
                    ontology_dict[evt_type]['template'].append(fields[1])
                    for i, arg in enumerate(args):
                        if arg != '':
                            ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                            ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
        return ontology_dict
    def prepare_data(self, sentences, evt_type, trigger):
        input_template, context = self.create_gold_gen(sentences, evt_type, trigger)
        length = len(input_template)
        # print(input_template)
        # print(context)
        results = []
        for i in range(length):
            input_tokens = self.tokenizer.encode_plus(input_template[i], context[i],
                                                      add_special_tokens=True,
                                                      add_prefix_space=True,
                                                      max_length=MAX_LENGTH,
                                                      truncation='only_second',
                                                      padding='max_length')
            # input_ids: vocabulary ids of the encoded tokens
            results.append(input_tokens['input_ids'])
        # Return a copy of the ontology entry with <trg> filled in, so the shared
        # ontology templates are never mutated with a stale trigger.
        temp = dict(self.ontology_dict[evt_type.replace('n/a', 'unspecified')])
        temp['template'] = [re.sub('<trg>', trigger, t) for t in temp['template']]
        return results, temp
class Runner():
    def __init__(self, load_ckpt='checkpoints/gen-RAMS-what-new-span/epoch=2-v0.ckpt'):
        model = 'gen'
        self.ckpt_name = 'gen-RAMS-pred'
        self.load_ckpt = load_ckpt
        self.dataset = 'RAMS'
        self.eval_only = True
        self.train_file = 'data/RAMS_1.0/data/train.jsonlines'
        self.val_file = 'data/RAMS_1.0/data/dev.jsonlines'
        self.test_file = 'data/RAMS_1.0/data/test.jsonlines'
        self.train_batch_size = 2
        self.eval_batch_size = 4
        self.learning_rate = 3e-5
        self.accumulate_grad_batches = 4
        self.num_train_epochs = 3
        parser = argparse.ArgumentParser()
        # Required parameters
        parser.add_argument(
            "--model",
            type=str,
            default=model
        )
        parser.add_argument(
            "--dataset",
            type=str,
            default=self.dataset
        )
        parser.add_argument('--tmp_dir', type=str)
        parser.add_argument(
            "--ckpt_name",
            default=self.ckpt_name,
            type=str,
            help="The output directory where the model checkpoints and predictions will be written.",
        )
        parser.add_argument(
            "--load_ckpt",
            default=self.load_ckpt,
            type=str,
        )
        parser.add_argument(
            "--train_file",
            default=self.train_file,
            type=str,
            help="The input training file. If a data dir is specified, will look for the file there. "
            + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
        )
        parser.add_argument(
            "--val_file",
            default=self.val_file,
            type=str,
            help="The input evaluation file. If a data dir is specified, will look for the file there. "
            + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
        )
        parser.add_argument(
            '--test_file',
            type=str,
            default=self.test_file,
        )
        parser.add_argument('--input_dir', type=str, default=None)
        parser.add_argument('--coref_dir', type=str, default='data/kairos/coref_outputs')
        parser.add_argument('--use_info', action='store_true', default=False,
                            help='Use informative mentions instead of the nearest mention.')
        parser.add_argument('--mark_trigger', action='store_true')
        parser.add_argument('--sample-gen', action='store_true', help='Do sampling during generation.')
        parser.add_argument("--train_batch_size", default=self.train_batch_size, type=int,
                            help="Batch size per GPU/CPU for training.")
        parser.add_argument(
            "--eval_batch_size", default=self.eval_batch_size, type=int, help="Batch size per GPU/CPU for evaluation."
        )
        parser.add_argument("--learning_rate", default=self.learning_rate, type=float,
                            help="The initial learning rate for Adam.")
        parser.add_argument(
            "--accumulate_grad_batches",
            type=int,
            default=self.accumulate_grad_batches,
            help="Number of update steps to accumulate before performing a backward/update pass.",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--gradient_clip_val", default=1.0, type=float, help="Max gradient norm.")
        parser.add_argument(
            "--num_train_epochs", default=self.num_train_epochs, type=int,
            help="Total number of training epochs to perform."
        )
        parser.add_argument(
            "--max_steps",
            default=-1,
            type=int,
            help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.",
        )
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--gpus", default=None, help='-1 means train on all GPUs')
        parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
        parser.add_argument(
            "--fp16",
            action="store_true",
            help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.",
        )
        parser.add_argument("--threads", type=int, default=1,
                            help="Multiple threads for converting examples to features.")
        self.args = parser.parse_args()
        self.model = GenIEModel(self.args)
        self.model.load_state_dict(torch.load(self.args.load_ckpt, map_location=self.model.device)['state_dict'])
    def pred(self, inputs):
        x = torch.stack([torch.LongTensor(u) for u in inputs])
        return self.model.pred(x)
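# Note: `inputs` is expected to be the list of padded token-id sequences produced by
# prepare_data above, and GenIEModel.pred (from src.genie.model) is assumed to return
# one decoded template string per input row, which handle() below parses.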
print('Loading data...')
dm1 = DataModule4('aida_ontology_cleaned.csv')
dm2 = DataModuleW('aida_ontology_fj-w-2.csv')
dm3 = DataModuleW('aida_ontology_fj-w-3.csv')
dm4 = DataModule4('aida_ontology_fj-5.csv')
print('Loading Model 1...')
runner1 = Runner('checkpoints/gen-RAMS-1-span/epoch=2-v1.ckpt')
print('Loading Model 2...')
runner2 = Runner('checkpoints/gen-RAMS-2-span/epoch=2-v0.ckpt')
print('Loading Model 3...')
runner3 = Runner('checkpoints/gen-RAMS-3-span/epoch=2-v0.ckpt')
print('Loading Model 4...')
runner4 = Runner('checkpoints/gen-RAMS-4-span/epoch=2-v0.ckpt')
def handle(sentences, trigger, temp=3, evt_type='contact.prevarication.broadcast'):
    # Index-based dispatch: the template dropdown passes 0-3, selecting dm1-dm4
    # and the matching runner1-runner4.
    dm = (dm1, dm2, dm3, dm4)[temp]
    runner = (runner1, runner2, runner3, runner4)[temp]
    x, argnames = dm.prepare_data(sentences, evt_type, trigger)
    ys = runner.pred(x)
    print(ys)
    results = []
    for y in ys:
        # Collapse runs of spaces so the split below yields one word per slot.
        while '  ' in y:
            y = y.replace('  ', ' ')
        result = y.strip(' ').split(' ')
        results.append(result)
    print(results)
    argss = []
    for n, template in enumerate(argnames['template']):
        template = template.split(' ')
        # print(template)
        args = []
        for i, w in enumerate(template):
            if '<arg' in w:
                # Map the <argN> slot to its role suffix, e.g. 'evt001arg01communicator' -> 'communicator'.
                m = re.match(r'evt\d+arg\d+(\w+)', argnames[re.match(r'<(\w+)>', w).group(1)])
                if m:
                    label = m.group(1)
                    if results[n][i] == '<arg>':
                        args.append(label + ': None')
                    else:
                        args.append(label + ': ' + results[n][i])
        argss.append(', '.join(args))
    return '\n'.join(argss)
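# Illustrative return value, one line per template (role names depend on the
# ontology entry; this example is hypothetical):
#   'communicator: official, recipient: reporters, topic: None'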
if __name__ == "__main__":
    # trigger = 'deceive'
    # sentences = """We are ashamed of them . " However , Mutko stopped short of admitting the doping scandal was state sponsored . " We are very sorry that athletes who tried to deceive us , and the world , were not caught sooner . We are very sorry because Russia is committed to upholding the highest standards in sport and is opposed to anything that threatens the Olympic values , " he said . English former heptathlete and Athens 2004 bronze medallist Kelly Sotherton was unhappy with Mutko 's plea for Russia 's ban to be lifted for Rio"""
    # print(handle(sentences, trigger))
    dm_key = list(dm1.ontology_dict.keys())
    print(len(dm_key))
    def get_tmp(index, evt_type):
        if index is None or evt_type is None:
            return ''
        dm = (dm1, dm2, dm3, dm4)[index]
        input_template = dm.ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        return '\n'.join(input_template)
    with gr.Blocks() as demo:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant="panel"):
                stens = gr.Text(label='Document')
                evt_type = gr.Dropdown(choices=dm_key, label='Event type')
                trigger = gr.Text(label='Trigger')
                temp = gr.Dropdown(choices=['Base template', 'Simple sub-template', 'Sub-template with semantic information', 'Sub-template with argument information'],
                                   type='index', value='Base template', label='Template')
                output_tmp = gr.Text(label='Template content')
                btn = gr.Button("Run")
            with gr.Column(variant="panel"):
                result = gr.Text(label='Output')
        evt_type.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        temp.change(get_tmp, inputs=[temp, evt_type], outputs=[output_tmp])
        btn.click(fn=handle, inputs=[stens, trigger, temp, evt_type], outputs=[result])
    demo.launch(server_name='0.0.0.0', server_port=6006, share=True)