|
|
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
import glob |
|
import json |
|
from pandas.io.json import json_normalize |
|
from nltk.tokenize import sent_tokenize |
|
import nltk |
|
import scipy.spatial |
|
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering |
|
from sentence_transformers import models, SentenceTransformer |
|
import torch |
|
import spacy |
|
import streamlit as st |
|
from utils import * |
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Unpickle ``listfile_3.data`` and sentence-split each article body in place.

    Every entry's element ``[1]`` is treated as the body: entries whose body
    equals ``[]`` are skipped, all others get their body replaced by the list
    returned from ``sent_tokenize``. The entries are mutated in place, so the
    returned object is the very list that was unpickled.

    NOTE(review): assumes the pickle holds a list of mutable (list) records —
    confirm against whatever produced ``listfile_3.data``.

    Returns:
        The unpickled article list with tokenized bodies.
    """
    with open('listfile_3.data', 'rb') as data_file:
        loaded = pickle.load(data_file)

    for entry in loaded:
        body = entry[1]
        if body != []:
            entry[1] = sent_tokenize(body)

    return loaded
|
|
|
@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Build a SentenceTransformer from the ``covidbert_nli`` BERT weights.

    The word-embedding module is followed by a pooling module configured for
    mean-token pooling only (CLS-token and max-token pooling are disabled).
    NOTE(review): ``'covidbert_nli'`` is presumably a local model directory —
    confirm it ships alongside the app.

    Returns:
        A ``SentenceTransformer`` made of the BERT and pooling modules.
    """
    bert = models.BERT('covidbert_nli')
    mean_pooling = models.Pooling(
        bert.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    return SentenceTransformer(modules=[bert, mean_pooling])
|
|
|
@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Return whatever object is stored in ``list_of_articles.pkl``.

    The result is handed to ``fetch_stage1`` in ``main()``; its exact layout
    is defined by that helper (NOTE(review): presumably precomputed article
    embeddings — confirm in ``utils``).

    ``pickle`` can execute arbitrary code while loading, so this file must
    only ever come from a trusted source.
    """
    with open('list_of_articles.pkl', 'rb') as pkl_file:
        return pickle.load(pkl_file)
|
|
|
@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    """Create a question-answering pipeline backed by ``graviraja/covidbert_squad``.

    The model and its tokenizer come from the same checkpoint; ``device=-1``
    pins the pipeline to the CPU regardless of GPU availability.

    Returns:
        A transformers ``pipeline`` for the ``"question-answering"`` task.
    """
    checkpoint = "graviraja/covidbert_squad"
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return pipeline(
        "question-answering",
        model=qa_model,
        tokenizer=qa_tokenizer,
        device=-1,
    )
|
|
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def _load_nlp_resources():
    """Fetch NLTK ``punkt`` data and load the spaCy ``en_core_web_sm`` pipeline once.

    Streamlit re-executes the whole script on every widget interaction. Doing
    this work directly in ``main()`` reloaded the spaCy model on each rerun.
    Caching it the same way as the other loaders keeps reruns cheap.

    Returns:
        The loaded spaCy pipeline object.
    """
    nltk.download('punkt')  # sentence-tokenizer data used by sent_tokenize
    return spacy.load('en_core_web_sm')


def main():
    """Render the Streamlit question-answering page.

    Loads the cached resources and reads a query from a text box. The query
    runs through ``fetch_stage1`` -> ``fetch_stage2`` -> ``fetch_stage3``,
    which are imported from ``utils``. The page then writes up to four
    answers, each with its score and the title of the source article.
    Blank queries skip the pipeline entirely.
    """
    # NOTE(review): the loop below shows four answers, matching the original
    # `count > 3` check; confirm whether three was intended.
    max_answers = 4

    spacy_nlp = _load_nlp_resources()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Article records from load_prep_data; element [0] is displayed as the title.
    articles = load_prep_data()

    model = build_sent_trans_model()
    model.to(device)

    list_of_articles = load_embedded_articles()

    # NOTE(review): this pipeline is built with device=-1 (CPU) even when CUDA
    # is available for the sentence model above.
    comprehension_model = load_comprehension_model()

    query = st.text_input("Enter Query", 'example query ', key="query")
    if not query.strip():
        # Nothing to search for; avoid running the expensive retrieval/QA stages.
        st.info("Please enter a query.")
        return

    query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
    results2 = fetch_stage2(results1, model, articles, query_embedding)
    results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

    if not results3:
        st.info("There isn't any answer")
        return

    # Each result is read as: res[0] -> index into `articles`, res[1] -> score,
    # res[2] -> answer text.
    for rank, res in enumerate(results3, start=1):
        st.write('{}> {}'.format(rank, res[2]))
        st.write('Score: %.4f' % (res[1]))
        st.write("From the article with title: {}".format(articles[res[0]][0]))
        st.write("\n")
        if rank >= max_answers:
            break
|
|
|
|
|
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|