COVID_NLI / app.py
hitz02's picture
Upload 3 files
07ab211
raw
history blame
3.2 kB
import pandas as pd
import numpy as np
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import streamlit as st
from utils import *
@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Load the pickled article list and split each article body into sentences.

    Each entry loaded from 'listfile_3.data' is indexed as ``entry[0]`` (the
    title shown to the user by main()) and ``entry[1]`` (the body). Unless the
    body is an empty list, it is replaced in place by
    ``sent_tokenize(entry[1])``. Requires the NLTK 'punkt' data to be
    available before this runs.

    Returns:
        The loaded list of article entries, with their bodies tokenized.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this is acceptable
    # only because the data file ships with the app. Never load user-supplied
    # files this way.
    with open('listfile_3.data', 'rb') as filehandle:
        articles = pickle.load(filehandle)
    for entry in articles:
        # Presumably entry[1] is the raw body string and [] marks an article
        # with no text -- TODO confirm against the data-preparation script.
        if entry[1] != []:
            entry[1] = sent_tokenize(entry[1])
    return articles
@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Assemble a SentenceTransformer on top of the 'covidbert_nli' BERT encoder.

    Token embeddings from the encoder are mean-pooled into one sentence
    vector; CLS-token pooling and max pooling are both disabled.
    ('covidbert_nli' is presumably a model directory shipped with the app --
    confirm.)

    Returns:
        The two-module (encoder + mean pooling) SentenceTransformer.
    """
    bert_encoder = models.BERT('covidbert_nli')
    embedding_dim = bert_encoder.get_word_embedding_dimension()
    mean_pooling = models.Pooling(
        embedding_dim,
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )
    return SentenceTransformer(modules=[bert_encoder, mean_pooling])
@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Unpickle and return the object stored in 'list_of_articles.pkl'.

    main() hands the result to fetch_stage1 together with the sentence model;
    presumably it holds pre-computed article embeddings -- confirm against
    utils.fetch_stage1.
    """
    with open('list_of_articles.pkl', 'rb') as pkl_file:
        return pickle.load(pkl_file)
@st.cache(allow_output_mutation=True)
def load_comprehension_model(model_name="graviraja/covidbert_squad", device=None):
    """Build the extractive question-answering pipeline.

    Args:
        model_name: Hugging Face model id (or path) used for both the QA
            model and its tokenizer.
        device: Device argument for ``transformers.pipeline``: -1 runs on
            CPU, a non-negative integer selects that CUDA device. ``None``
            (the default) picks CUDA device 0 when CUDA is available and the
            CPU otherwise, mirroring the device choice main() makes for the
            sentence model.

    Returns:
        A ``transformers`` question-answering pipeline.
    """
    if device is None:
        # pipeline's device=-1 means CPU, not "any available GPU"; a GPU must
        # be requested explicitly by its non-negative index.
        device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        device=device,
    )
@st.cache(allow_output_mutation=True)
def _load_nlp_resources():
    """Download NLTK 'punkt' data and load the spaCy 'en_core_web_sm' pipeline.

    Streamlit re-executes the script on every interaction; caching keeps the
    download and the spaCy model load from repeating on each rerun. 'punkt'
    must be present before load_prep_data() calls sent_tokenize.
    """
    nltk.download('punkt')
    return spacy.load('en_core_web_sm')


def _show_answers(results, articles, limit=4):
    """Write up to *limit* answers to the page, numbered from 1.

    Each result is read as ``(article_index, score, answer_text)``, and
    ``articles[article_index][0]`` is shown as the source article's title.
    """
    for rank, res in enumerate(results, start=1):
        if rank > limit:
            break
        st.write('{}> {}'.format(rank, res[2]))
        st.write('Score: %.4f' % (res[1]))
        st.write("From the article with title: {}".format(articles[res[0]][0]))
        st.write("\n")


def main():
    """Streamlit entry point: answer the user's query over the article corpus.

    Loads the data and models, runs the three retrieval stages from utils
    (fetch_stage1 -> fetch_stage2 -> fetch_stage3) on the query, and shows the
    top answers, or a notice when there are none.
    """
    spacy_nlp = _load_nlp_resources()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Article entries from load_prep_data(): title at [0], sentences at [1].
    articles = load_prep_data()
    model = build_sent_trans_model()
    model.to(device)
    list_of_articles = load_embedded_articles()
    comprehension_model = load_comprehension_model()

    query = st.text_input("Enter Query", 'example query ', key="query")
    if not query.strip():
        # Don't run the whole retrieval pipeline on an empty query.
        st.info("Please enter a query")
        return

    query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
    results2 = fetch_stage2(results1, model, articles, query_embedding)
    results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

    if results3:
        _show_answers(results3, articles)
    else:
        st.info("There isn't any answer")


if __name__ == '__main__':
    main()