import gradio as gr
import numpy as np
import time
import hashlib
import math
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, pipeline
from tqdm import tqdm
import os
import textract
from scipy.special import softmax
import pandas as pd
from datetime import datetime

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Retriever: passage/query encoder; passages are scored by dot product with the
# query's first-token (CLS) embedding (see cls_pooling and predict).
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(device).eval()

# Reader: extractive question-answering model run on the top-ranked passages.
tokenizer_ans = AutoTokenizer.from_pretrained("deepset/roberta-large-squad2")
model_ans = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-large-squad2").to(device).eval()

if device == 'cuda:0':
    pipe = pipeline("question-answering", model_ans, tokenizer=tokenizer_ans, device=0)
else:
    pipe = pipeline("question-answering", model_ans, tokenizer=tokenizer_ans)


def cls_pooling(model_output):
    """Return the hidden state of the first token of every sequence in the batch."""
    return model_output.last_hidden_state[:, 0]


def encode_query(query):
    """Embed a single query string with the retriever.

    Returns a CPU tensor of shape (1, hidden).
    """
    encoded_input = tokenizer(query, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    embeddings = cls_pooling(model_output)
    return embeddings.cpu()


def encode_docs(docs, maxlen=64, stride=32):
    """Split a document into overlapping word windows, embed each window and cache the result.

    Args:
        docs: ``(name, text)`` tuple; ``name`` is used to build the cache file names.
        maxlen: number of new words contributed by each window.
        stride: number of words each window repeats from the preceding block of ``maxlen``
            words (the first window instead looks ``stride`` words ahead).

    Returns:
        ``(embeddings, spans, file_names)``: a float32 array of shape (1, n_windows, hidden),
        the text of every window, and ``name`` repeated once per window.

    Side effects:
        Writes ``emb_<name>.npy``, ``spans_<name>.npy`` and ``file_<name>.npy`` (each a
        pickled ``{index: value}`` dict) to the current working directory.
    """
    name, text = docs
    words = text.split(" ")

    if len(words) < maxlen:
        # Short document: one window containing the whole text.
        spans = [" ".join(words)]
    else:
        spans = []
        # ceil (not int()+1) so an exact multiple of maxlen does not yield an extra
        # window made only of already-covered overlap words.
        num_iters = math.ceil(len(words) / maxlen)
        for i in range(num_iters):
            if i == 0:
                window = words[:maxlen + stride]
            else:
                # Last `stride` words of the previous block (capped at that block, and
                # no overlap at all when stride == 0), then the next `maxlen` words.
                start = max((i - 1) * maxlen, i * maxlen - stride)
                window = words[start:(i + 1) * maxlen]
            spans.append(" ".join(window))
    file_names = [name] * len(spans)

    embeddings = []
    with torch.no_grad():
        for span in tqdm(spans):
            encoded = tokenizer(span, return_tensors='pt', truncation=True).to(device)
            model_output = model(**encoded, return_dict=True)
            embeddings.append(cls_pooling(model_output))
    # stack -> (n_windows, 1, hidden); transpose -> (1, n_windows, hidden)
    embeddings = torch.stack(embeddings).transpose(0, 1).cpu().numpy().astype(np.float32)

    # Same on-disk format as before ({index: value} dicts) so existing caches stay loadable.
    np.save("emb_{}.npy".format(name), dict(enumerate(embeddings)))
    np.save("spans_{}.npy".format(name), dict(enumerate(spans)))
    np.save("file_{}.npy".format(name), dict(enumerate(file_names)))
    return embeddings, spans, file_names


def _doc_name(path):
    """Derive the cache name of an uploaded file from its path.

    Takes the base name (splitting on either path separator, so it works for both
    Windows and POSIX temp paths), strips the last extension, then drops the final
    8 characters -- presumably the random suffix the upload temp file carries
    (NOTE(review): confirm against the Gradio version in use).
    """
    base = path.replace("\\", "/").split("/")[-1]
    return os.path.splitext(base)[0][:-8]


def _load_or_encode_document(path, name):
    """Return ``(doc_emb, doc_text, file_names)`` for a document, reusing the .npy cache if present.

    ``doc_emb`` has shape (n_windows, 768). On a cache miss the file is extracted with
    textract, encoded with ``encode_docs`` (which writes the cache), and the cleaned
    text is saved to ``<name>.txt``.
    """
    cache_files = ["{}_{}.npy".format(prefix, name) for prefix in ("emb", "spans", "file")]
    # Check for the files that are actually loaded below (they live in the cwd).
    if all(os.path.exists(f) for f in cache_files):
        emb_dict, text_dict, names_dict = (np.load(f, allow_pickle=True).item() for f in cache_files)
        doc_emb = np.array(list(emb_dict.values())).reshape(-1, 768)
        return doc_emb, list(text_dict.values()), list(names_dict.values())

    text = textract.process(path).decode('utf8')
    text = text.replace("\r", " ").replace("\n", " ").replace(" . ", " ")
    doc_emb, doc_text, file_names = encode_docs((name, text), maxlen=64, stride=32)
    doc_emb = doc_emb.reshape(-1, 768)
    with open("{}.txt".format(name), "w", encoding="utf-8") as f:
        f.write(text)
    return doc_emb, doc_text, file_names


def _append_row(table, passage, answer, probabilities, source):
    """Append one result row, followed by a "---" separator row, to every column of ``table``."""
    for column, value in (("Passage", passage), ("Answer", answer),
                          ("Probabilities", probabilities), ("Source", source)):
        table[column].append(value)
        table[column].append("---")


def predict(query, data):
    """Retrieve the passages of an uploaded document most relevant to ``query`` and answer it.

    Args:
        query: question text.
        data: uploaded-file object from the Gradio ``File`` input; only ``data.name``
            is used (as the path handed to textract and to derive the cache name).

    Returns:
        pandas DataFrame with columns Passage / Answer / Probabilities / Source: one row per
        top-ranked passage (at most 5), each followed by a "---" separator row. The answer
        span is upper-cased inside its passage.

    Side effects:
        Reads/writes a ``<sha256>.csv`` answer cache, the per-document ``.npy``/``.txt``
        caches, and appends a line to ``HISTORY.txt``.
    """
    name_to_save = _doc_name(data.name)
    st = str([query, name_to_save])
    # Stable digest: the built-in hash() of a str is salted per process, so a cache keyed
    # on it never hits after a restart.
    key = hashlib.sha256(st.encode()).hexdigest()
    hist = st + " " + key
    current_time = datetime.now().strftime("%H:%M:%S")

    cache_csv = "{}.csv".format(key)
    if os.path.exists(cache_csv):
        try:
            return pd.read_csv(cache_csv)
        except (OSError, ValueError) as e:  # unreadable/corrupt cache entry: recompute
            print(e)
    print(st)

    doc_emb, doc_text, file_names = _load_or_encode_document(data.name, name_to_save)

    start = time.time()
    query_emb = encode_query(query).numpy()
    scores = np.matmul(query_emb, doc_emb.transpose(1, 0))[0].tolist()
    ranked = sorted(zip(doc_text, scores, file_names), key=lambda x: x[1], reverse=True)

    k = 5
    top = ranked[:k]
    # Softmax over the top-k retrieval scores, reported as P(p|q).
    probs = softmax([score for _, score, _ in top])

    table = {"Passage": [], "Answer": [], "Probabilities": [], "Source": []}
    for i, (passage, _, source) in enumerate(top):
        passage = passage.replace("\n", "").replace(" . ", " ")
        p_passage = probs[i]
        # Run the (expensive) reader only on likely passages: P(p|q) > 0.1, or > 0.05
        # for the three best-ranked ones.
        if p_passage > 0.1 or (i < 3 and p_passage > 0.05):
            ans = pipe({'question': query, 'context': passage})
            answer = str(ans["answer"])
            probabilities = "P(a|p): {}, P(a|p,q): {}, P(p|q): {}".format(
                round(ans["score"], 5), round(ans["score"] * p_passage, 5), round(p_passage, 5))
            # Highlight the answer inside its passage.
            passage = passage.replace(answer, answer.upper())
            _append_row(table, passage, answer.upper(), probabilities, source)
        else:
            _append_row(table, passage, "no_answer_calculated",
                        "P(p|q): {}".format(round(p_passage, 5)), source)

    df = pd.DataFrame(table)
    print("time: " + str(time.time() - start))
    with open("HISTORY.txt", "a", encoding="utf-8") as f:
        f.write(hist + " " + str(current_time) + "\n")
    df.to_csv(cache_csv, index=False)
    return df


# NOTE(review): gr.inputs / gr.outputs, allow_screenshot and enable_queue are legacy
# Gradio APIs; confirm they match the pinned gradio version.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Textbox(default="What is Open-domain question answering?"),
        gr.inputs.File(),
    ],
    outputs=[
        gr.outputs.Dataframe(),
    ],
    allow_flagging="manual",
    flagging_options=["correct", "wrong"],
    allow_screenshot=False,
)
iface.launch(share=True, enable_queue=True, show_error=True)