import os
import argparse

import tokenizers
import keras
import tensorflow
import spacy
import pandas

import qarac.corpora.BNCorpus
import qarac.corpora.Batcher
import qarac.models.qarac_base_model
import qarac.utils.CoreferenceResolver


def decoder_loss(y_true, y_pred):
    # The decoder returns an output object whose raw scores live on .logits;
    # the keyword for keras.losses.sparse_categorical_crossentropy is
    # from_logits, not logits.
    return keras.losses.sparse_categorical_crossentropy(y_true,
                                                        y_pred.logits,
                                                        from_logits=True)


def capitalise(token, i):
    # Title-case the first word of a question and proper nouns (NNP/NNPS tags);
    # lower-case everything else.
    return (token.text_with_ws.title()
            if i == 0 or token.tag_.startswith('NNP')
            else token.text_with_ws.lower())


def clean_question(doc):
    words = [capitalise(token, i) for (i, token) in enumerate(doc)]
    # text_with_ws may leave trailing whitespace on the final token, so strip
    # it before checking whether the question already ends in a question mark.
    if not words[-1].strip().endswith('?'):
        words.append('?')
    return ''.join(words)


def prepare_wiki_qa(filename, outfilename):
    data = pandas.read_csv(filename, sep='\t')
    # WikiQA question IDs look like 'Q123'; drop the leading letter.
    data['QNum'] = data['QuestionID'].apply(lambda x: int(x[1:]))
    nlp = spacy.load('en_core_web_trf')
    predictor = qarac.utils.CoreferenceResolver.CoreferenceResolver()
    # Resolve coreferences within each question's block of candidate answers.
    data['Resolved_answer'] = data.groupby('QNum')['Sentence'].transform(predictor)
    unique_questions = data.groupby('QNum')['Question'].first()
    cleaned_questions = pandas.Series([clean_question(doc)
                                       for doc in nlp.pipe(unique_questions)],
                                      index=unique_questions.index)
    for (i, question) in cleaned_questions.items():
        data.loc[data['QNum'] == i, 'Cleaned_question'] = question
    data[['Cleaned_question', 'Resolved_answer', 'Label']].to_csv(outfilename)


def train_base_model(task, filename):
    tokenizer = tokenizers.Tokenizer.from_pretrained('xlm-roberta-base')
    # The special-token strings were blank in the source, apparently because
    # angle-bracketed markers were stripped as HTML; start/end/pad markers
    # are assumed here.
    tokenizer.add_special_tokens(['<start>', '<end>', '<pad>'])
    tokenizer.save('/'.join([os.environ['HOME'],
                             'QARAC',
                             'models',
                             'tokenizer.json']))
    bnc = qarac.corpora.BNCorpus.BNCorpus(tokenizer=tokenizer,
                                          task=task)
    # Hold out 1% of the British National Corpus for evaluation.
    (train, test) = bnc.split(0.01)
    train_data = qarac.corpora.Batcher.Batcher(train)
    model = qarac.models.qarac_base_model.qarac_base_model(tokenizer.get_vocab_size(),
                                                           768,
                                                           12,
                                                           task == 'decode')
    optimizer = keras.optimizers.Nadam(
        learning_rate=keras.optimizers.schedules.ExponentialDecay(1.0e-5,
                                                                  100,
                                                                  0.99))
    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(train_data,
              epochs=100,
              workers=16,
              use_multiprocessing=True)
    test_data = qarac.corpora.Batcher.Batcher(test)
    print(model.evaluate(test_data))
    model.save(filename)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='QARAC',
                                     description='Experimental NLP system, aimed at improving factual accuracy')
    parser.add_argument('task')
    parser.add_argument('-f', '--filename')
    parser.add_argument('-t', '--training-task')
    parser.add_argument('-o', '--outputfile')
    args = parser.parse_args()
    if args.task == 'train_base_model':
        train_base_model(args.training_task, args.filename)
    elif args.task == 'prepare_wiki_qa':
        prepare_wiki_qa(args.filename, args.outputfile)
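
# Example invocations (a sketch only: the script name and the data/model
# paths below are assumptions for illustration, not taken from the source):
#
#   python scripts.py train_base_model --training-task decode --filename base_decoder
#   python scripts.py prepare_wiki_qa --filename WikiQA-train.tsv --outputfile WikiQA-cleaned.csv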