Upload 3 files
- app.py +109 -0
- requirements.txt +9 -0
- utils.py +123 -0
app.py
ADDED
@@ -0,0 +1,109 @@
import pandas as pd
import numpy as np
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import streamlit as st
from utils import *


@st.cache(allow_output_mutation=True)
def load_prep_data():
    with open('listfile_3.data', 'rb') as filehandle:
        articles = pickle.load(filehandle)

    # Split each article's abstract into sentences
    for article in range(len(articles)):
        if articles[article][1] != []:
            articles[article][1] = sent_tokenize(articles[article][1])

    return articles


@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    word_embedding_model = models.BERT('covidbert_nli')

    # Add the mean pooling strategy
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                   pooling_mode_mean_tokens=True,
                                   pooling_mode_cls_token=False,
                                   pooling_mode_max_tokens=False)

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    return model


@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    with open('list_of_articles.pkl', 'rb') as f:
        list_of_articles = pickle.load(f)

    return list_of_articles


@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    # device=-1 runs the pipeline on CPU; set device=0 to use the first GPU
    comprehension_model = pipeline("question-answering",
                                   model=AutoModelForQuestionAnswering.from_pretrained("graviraja/covidbert_squad"),
                                   tokenizer=AutoTokenizer.from_pretrained("graviraja/covidbert_squad"),
                                   device=-1)

    return comprehension_model


def main():

    nltk.download('punkt')
    spacy_nlp = spacy.load('en_core_web_sm')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    embeddings = load_prep_data()

    model = build_sent_trans_model()
    model.to(device)

    list_of_articles = load_embedded_articles()

    comprehension_model = load_comprehension_model()

    query = st.text_input("Enter Query", 'example query ', key="query")

    query_embedding, results1 = fetch_stage1(query, model, list_of_articles)

    results2 = fetch_stage2(results1, model, embeddings, query_embedding)

    results3 = fetch_stage3(results2, query, embeddings, comprehension_model, spacy_nlp)

    if results3:
        # Show at most the top four answers
        count = 1
        for res in results3:
            st.write('{}> {}'.format(count, res[2]))
            st.write('Score: %.4f' % res[1])
            st.write("From the article with title: {}".format(embeddings[res[0]][0]))
            st.write("\n")
            if count > 3:
                break
            count += 1
    else:
        st.info("There isn't any answer")


if __name__ == '__main__':
    main()
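The app expects the two pickle files above to be pre-built. A minimal sketch of the layout they appear to have, inferred from how app.py and utils.py index them (the dummy values are placeholders, not real data):

import pickle
import numpy as np

# listfile_3.data: one entry per article -> [title, abstract_text, body_texts],
# where each element of body_texts carries the paragraph string at position [0].
articles = [["Example title",
             "First abstract sentence. Second abstract sentence.",
             [("Example body paragraph.",)]]]
with open('listfile_3.data', 'wb') as f:
    pickle.dump(articles, f)

# list_of_articles.pkl: one entry per article -> list of abstract-sentence
# embeddings (1-D arrays), so np.vstack(article) works inside fetch_stage1.
list_of_articles = [[np.zeros(768, dtype=np.float32),
                     np.zeros(768, dtype=np.float32)]]
with open('list_of_articles.pkl', 'wb') as f:
    pickle.dump(list_of_articles, f)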
requirements.txt
ADDED
@@ -0,0 +1,9 @@
nltk
pandas
scipy
numpy
sentence-transformers==0.2.5.1
transformers==2.5.1
spacy
streamlit
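Two runtime dependencies are not covered by the list above: torch (installed transitively by sentence-transformers) and the spaCy English model. app.py downloads the NLTK punkt data at startup but not the spaCy model, so a one-time setup sketch along these lines is assumed:

import nltk
import spacy.cli

nltk.download('punkt')                # sentence tokenizer (also fetched at runtime by app.py)
spacy.cli.download('en_core_web_sm')  # model loaded via spacy.load() in app.py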
utils.py
ADDED
@@ -0,0 +1,123 @@
import pandas as pd
import numpy as np
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer


def get_full_sentence(spacy_nlp, para_text, start_index, end_index):
    """
    Return the sentence(s) of the given paragraph (body text) that
    contain the character span [start_index, end_index].
    """
    sent_start = 0
    sent_end = len(para_text)
    for sent in spacy_nlp(para_text).sents:
        if (sent.start_char <= start_index) and (sent.end_char >= start_index):
            sent_start = sent.start_char
        if (sent.start_char <= end_index) and (sent.end_char >= end_index):
            sent_end = sent.end_char
    sentence = para_text[sent_start:sent_end + 1]
    return sentence

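# Illustration (hypothetical values): if the model's answer span falls inside
# one sentence of para_text, the loop snaps both boundaries to that sentence,
# so the caller gets the whole containing sentence back, e.g.:
#   get_full_sentence(spacy_nlp, "Alpha is one. Beta is two.", 17, 20)
#   -> "Beta is two."
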
def fetch_stage1(query, model, list_of_articles):
    """
    Compare every article's abstract content with the query.
    """
    # Encode the query
    query_embedding = model.encode([query])[0]

    all_abs_distances = []

    for idx_of_article, article in enumerate(list_of_articles):
        if article:
            distances = []
            cdists = scipy.spatial.distance.cdist([query_embedding], np.vstack(article), "cosine").reshape(-1, 1)
            for idx, sentence in enumerate(article):
                distances.append((idx, 1 - cdists[idx][0]))

            results = sorted(distances, key=lambda x: x[1], reverse=True)
            if results:
                all_abs_distances.append((idx_of_article, results[0][0], results[0][1]))

    results = sorted(all_abs_distances, key=lambda x: x[2], reverse=True)

    return query_embedding, results

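# Hedged usage sketch (dummy inputs; the real list comes from list_of_articles.pkl):
#   abstract_sents = ["Vaccines reduce severity.", "Masks limit spread."]
#   dummy_articles = [list(model.encode(abstract_sents))]
#   q_emb, ranked = fetch_stage1("vaccine effectiveness", model, dummy_articles)
# Each entry of `ranked` is (article_index, best_sentence_index, cosine_similarity),
# sorted best-first, so ranked[0] points at the most relevant abstract.
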
def fetch_stage2(results, model, embeddings, query_embedding):
    """
    Take the 20 most similar articles, based on the relevant abstracts,
    and compare all of their body texts to the query.
    """
    all_text_distances = []
    for top in results[0:20]:
        article_idx = top[0]

        # Encode only the body texts of the 20 best articles
        body_texts = [text[0] for text in embeddings[article_idx][2]]
        body_text_embeddings = model.encode(body_texts, show_progress_bar=False)

        qbody = scipy.spatial.distance.cdist([query_embedding],
                                             np.vstack(body_text_embeddings),
                                             "cosine").reshape(-1, 1)

        body_text_distances = [(idx, 1 - dist[0]) for idx, dist in enumerate(qbody)]

        results = sorted(body_text_distances, key=lambda x: x[1], reverse=True)

        if results:
            all_text_distances.append((article_idx, results[0][0], results[0][1]))

    results = sorted(all_text_distances, key=lambda x: x[2], reverse=True)

    return results

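# Note: cdist(..., "cosine") returns cosine *distances*; 1 - distance converts
# them back to similarities, so larger scores rank higher in the sorts above.
# Each returned tuple is (article_index, best_paragraph_index, similarity).
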
def fetch_stage3(results, query, embeddings, comprehension_model, spacy_nlp):
    """
    For the top 20 retrieved paragraphs, run the comprehension model
    on each paragraph and extract an answer.
    """
    answers = []

    for top_text in results[0:20]:
        article_idx = top_text[0]
        body_text_idx = top_text[1]

        query_ = {"context": embeddings[article_idx][2][body_text_idx][0], "question": query}
        pred = comprehension_model(query_, topk=1)

        # If there is any answer
        if pred["answer"] and round(pred["score"], 4) > 0:
            # Take the containing sentence from the paragraph
            sent = get_full_sentence(spacy_nlp, query_['context'], pred["start"], pred["end"])
            answers.append((article_idx, round(pred["score"], 4), sent))

    results = sorted(answers, key=lambda x: x[1], reverse=True)

    return results
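Taken together, the three stages chain exactly as in app.py. A condensed sketch outside Streamlit, assuming the pickles sketched earlier and a local covidbert_nli weights directory are in place (the query string is a placeholder):

import pickle
import spacy
from sentence_transformers import models, SentenceTransformer
from transformers import pipeline
from utils import fetch_stage1, fetch_stage2, fetch_stage3

word_emb = models.BERT('covidbert_nli')  # local weights directory
pooling = models.Pooling(word_emb.get_word_embedding_dimension(),
                         pooling_mode_mean_tokens=True,
                         pooling_mode_cls_token=False,
                         pooling_mode_max_tokens=False)
model = SentenceTransformer(modules=[word_emb, pooling])
qa = pipeline("question-answering", model="graviraja/covidbert_squad",
              tokenizer="graviraja/covidbert_squad", device=-1)
nlp = spacy.load('en_core_web_sm')

with open('listfile_3.data', 'rb') as f:
    embeddings = pickle.load(f)
with open('list_of_articles.pkl', 'rb') as f:
    list_of_articles = pickle.load(f)

query = "What is the incubation period?"
q_emb, s1 = fetch_stage1(query, model, list_of_articles)
s2 = fetch_stage2(s1, model, embeddings, q_emb)
for art_idx, score, sentence in fetch_stage3(s2, query, embeddings, qa, nlp)[:3]:
    print(score, embeddings[art_idx][0], sentence)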