"""Streamlit app that answers a free-text query from a pickled article set.

A query is embedded with a sentence-transformer model, passed through the
three ``fetch_stage*`` helpers star-imported from ``utils``, and the
resulting answers are rendered on the page.
"""
import glob
import json
import pickle

import nltk
import numpy as np
import pandas as pd
import scipy.spatial
import spacy
import streamlit as st
import torch
from nltk.tokenize import sent_tokenize
from pandas.io.json import json_normalize
from sentence_transformers import models, SentenceTransformer
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering

from utils import *


@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Load the pickled articles and split each non-empty entry into sentences.

    Reads ``listfile_3.data`` and, for every entry whose element ``[1]`` is
    not an empty list, replaces that element in place with the output of
    ``sent_tokenize``.

    Returns:
        The unpickled article list with element ``[1]`` tokenised where
        present. ``main`` reads element ``[0]`` as the article title;
        element ``[1]`` is presumably the article text (TODO confirm).
    """
    # NOTE: pickle can execute arbitrary code on load; this file must come
    # from a trusted source.
    with open('listfile_3.data', 'rb') as filehandle:
        articles = pickle.load(filehandle)

    for article in articles:
        # Keep the explicit comparison with [] so that only the empty-list
        # placeholder is skipped, exactly as before.
        if article[1] != []:
            article[1] = sent_tokenize(article[1])
    return articles


@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Build the sentence encoder: ``covidbert_nli`` BERT plus mean pooling.

    Returns:
        A ``SentenceTransformer`` made of the word-embedding model followed
        by a pooling layer that averages the token embeddings.
    """
    word_embedding_model = models.BERT('covidbert_nli')

    # Mean pooling only; CLS-token and max pooling are switched off.
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    return SentenceTransformer(modules=[word_embedding_model, pooling_model])


@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Load and return the object pickled in ``list_of_articles.pkl``.

    The contents are presumably precomputed article embeddings consumed by
    ``fetch_stage1`` (TODO confirm against ``utils``).
    """
    # NOTE: pickle can execute arbitrary code on load; this file must come
    # from a trusted source.
    with open('list_of_articles.pkl', 'rb') as f:
        list_of_articles = pickle.load(f)
    return list_of_articles


@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    """Create the question-answering pipeline for ``graviraja/covidbert_squad``.

    Returns:
        A ``transformers`` question-answering pipeline placed on the first
        CUDA device when one is available, otherwise on the CPU.
    """
    model_name = "graviraja/covidbert_squad"

    # The pipeline takes a device ordinal: -1 means CPU, a non-negative
    # value selects that CUDA device. The previous hard-coded -1 always ran
    # on the CPU even though the intent was to use the available GPU.
    device = 0 if torch.cuda.is_available() else -1

    comprehension_model = pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        device=device,
    )
    return comprehension_model


def main():
    """Render the query box, run the three retrieval stages and show answers.

    At most four answers are written to the page; an info message is shown
    when the final stage returns nothing.
    """
    # sent_tokenize (used in load_prep_data) needs the punkt data.
    nltk.download('punkt')
    spacy_nlp = spacy.load('en_core_web_sm')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    articles = load_prep_data()
    model = build_sent_trans_model()
    model.to(device)
    list_of_articles = load_embedded_articles()
    comprehension_model = load_comprehension_model()

    query = st.text_input("Enter Query",'example query ',key="query")

    # The stage helpers come from utils; each stage consumes the previous
    # stage's results.
    query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
    results2 = fetch_stage2(results1, model, articles, query_embedding)
    results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

    if results3:
        # Each result is indexed as (article index, score, answer text).
        for rank, res in enumerate(results3, start=1):
            st.write('{}> {}'.format(rank, res[2]))
            st.write('Score: %.4f' % (res[1]))
            st.write("From the article with title: {}".format(articles[res[0]][0]))
            st.write("\n")
            # Stop after the fourth answer has been shown.
            if rank > 3:
                break
    else:
        st.info("There isn't any answer")


if __name__ == '__main__':
    main()