hitz02 committed on
Commit 07ab211
1 Parent(s): 556f46d

Upload 3 files

Files changed (3)
  1. app.py +109 -0
  2. requirements.txt +9 -0
  3. utils.py +123 -0
app.py ADDED
@@ -0,0 +1,109 @@
+
+ import pandas as pd
+ import numpy as np
+ import pickle
+ import glob
+ import json
+ from pandas.io.json import json_normalize
+ from nltk.tokenize import sent_tokenize
+ import nltk
+ import scipy.spatial
+ from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
+ from sentence_transformers import models, SentenceTransformer
+ import torch
+ import spacy
+ import streamlit as st
+ from utils import *
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_prep_data():
+     # Load the pickled articles and split each abstract into sentences.
+     with open('listfile_3.data', 'rb') as filehandle:
+         articles = pickle.load(filehandle)
+
+     for article in range(len(articles)):
+         if articles[article][1] != []:
+             articles[article][1] = sent_tokenize(articles[article][1])
+
+     return articles
+
+
+ @st.cache(allow_output_mutation=True)
+ def build_sent_trans_model():
+     word_embedding_model = models.BERT('covidbert_nli')
+
+     # Add the pooling strategy of Mean
+     pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
+                                    pooling_mode_mean_tokens=True,
+                                    pooling_mode_cls_token=False,
+                                    pooling_mode_max_tokens=False)
+
+     model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+     return model
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_embedded_articles():
+     with open('list_of_articles.pkl', 'rb') as f:
+         list_of_articles = pickle.load(f)
+
+     return list_of_articles
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_comprehension_model():
+     # device=-1 runs the pipeline on CPU; pass a GPU index (e.g. device=0) to use a GPU.
+     comprehension_model = pipeline("question-answering",
+                                    model=AutoModelForQuestionAnswering.from_pretrained("graviraja/covidbert_squad"),
+                                    tokenizer=AutoTokenizer.from_pretrained("graviraja/covidbert_squad"),
+                                    device=-1)
+
+     return comprehension_model
+
+
+ def main():
+     nltk.download('punkt')
+     spacy_nlp = spacy.load('en_core_web_sm')
+
+     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+     embeddings = load_prep_data()
+
+     model = build_sent_trans_model()
+     model.to(device)
+
+     list_of_articles = load_embedded_articles()
+
+     comprehension_model = load_comprehension_model()
+
+     query = st.text_input("Enter Query", 'example query ', key="query")
+
+     # Stage 1: rank articles by how well their abstract sentences match the query.
+     query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
+
+     # Stage 2: within the best articles, rank body-text paragraphs against the query.
+     results2 = fetch_stage2(results1, model, embeddings, query_embedding)
+
+     # Stage 3: run the question-answering model over the best paragraphs.
+     results3 = fetch_stage3(results2, query, embeddings, comprehension_model, spacy_nlp)
+
+     if results3:
+         count = 1
+
+         for res in results3:
+             st.write('{}> {}'.format(count, res[2]))
+             st.write('Score: %.4f' % (res[1]))
+             st.write("From the article with title: {}".format(embeddings[res[0]][0]))
+             st.write("\n")
+             if count > 3:
+                 break
+             count += 1
+     else:
+         st.info("No answer was found.")
+
+
+ if __name__ == '__main__':
+     main()
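
Neither pickled input is documented in this commit, so the layout below is inferred from how app.py and utils.py index them; treat it as an assumption rather than a spec. A minimal sketch:

# Inferred layout (assumption, based on the indexing in app.py / utils.py):
#   listfile_3.data       -> articles[i] == [title, abstract, body_texts]
#                            where body_texts[j][0] is the j-th paragraph's text
#   list_of_articles.pkl  -> list_of_articles[i] == sentence embeddings of article i's abstract
import pickle

with open('listfile_3.data', 'rb') as fh:
    articles = pickle.load(fh)

print(articles[0][0])        # title of the first article
print(articles[0][2][0][0])  # text of its first body paragraph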
requirements.txt ADDED
@@ -0,0 +1,9 @@
+
+ nltk
+ pandas
+ scipy
+ numpy
+ sentence-transformers==0.2.5.1
+ transformers==2.5.1
+ spacy
+ streamlit
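
One note on this list: app.py also imports torch directly (it is only pulled in indirectly by the pinned libraries), loads spaCy's en_core_web_sm model, and fetches NLTK's punkt tokenizer at startup. A quick pre-flight check for those runtime pieces (a sketch, not part of the commit):

# Sketch: sanity-check the runtime pieces that requirements.txt does not list explicitly.
import nltk
import spacy
import torch  # not pinned here; installed transitively with sentence-transformers

nltk.download('punkt')        # app.py's main() also does this on every run
spacy.load('en_core_web_sm')  # install first with: python -m spacy download en_core_web_sm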
utils.py ADDED
@@ -0,0 +1,123 @@
+
+ import pandas as pd
+ import numpy as np
+ import pickle
+ import glob
+ import json
+ from pandas.io.json import json_normalize
+ from nltk.tokenize import sent_tokenize
+ import nltk
+ import scipy.spatial
+ from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
+ from sentence_transformers import models, SentenceTransformer
+
+
+ def get_full_sentence(spacy_nlp, para_text, start_index, end_index):
+     """
+     Return the sentence(s) of the given paragraph (body text) that contain
+     the character span [start_index, end_index].
+     """
+     sent_start = 0
+     sent_end = len(para_text)
+     for sent in spacy_nlp(para_text).sents:
+         if (sent.start_char <= start_index) and (sent.end_char >= start_index):
+             sent_start = sent.start_char
+         if (sent.start_char <= end_index) and (sent.end_char >= end_index):
+             sent_end = sent.end_char
+     sentence = para_text[sent_start:sent_end + 1]
+     return sentence
+
+
+ def fetch_stage1(query, model, list_of_articles):
+     """
+     Compare the query with every article's abstract sentences and return,
+     per article, the index and cosine similarity of the best-matching sentence.
+     """
+
+     # Encode the query
+     query_embedding = model.encode([query])[0]
+
+     all_abs_distances = []
+
+     for idx_of_article, article in enumerate(list_of_articles):
+         if article:
+             distances = []
+             # Cosine distance between the query and every abstract-sentence embedding
+             cdists = scipy.spatial.distance.cdist([query_embedding], np.vstack(article), "cosine").reshape(-1, 1)
+             for idx, sentence in enumerate(article):
+                 distances.append((idx, 1 - cdists[idx][0]))
+
+             results = sorted(distances, key=lambda x: x[1], reverse=True)
+             if results:
+                 all_abs_distances.append((idx_of_article, results[0][0], results[0][1]))
+
+     results = sorted(all_abs_distances, key=lambda x: x[2], reverse=True)
+
+     return query_embedding, results
+
+
+ def fetch_stage2(results, model, embeddings, query_embedding):
+     """
+     Take the 20 most similar articles (by abstract) and compare all of their
+     body-text paragraphs to the query.
+     """
+
+     all_text_distances = []
+     for top in results[0:20]:
+         article_idx = top[0]
+
+         # Encode only the body texts of the 20 best articles
+         body_texts = [text[0] for text in embeddings[article_idx][2]]
+         body_text_embeddings = model.encode(body_texts, show_progress_bar=False)
+
+         # Cosine distance between the query and every body-text paragraph embedding
+         qbody = scipy.spatial.distance.cdist([query_embedding],
+                                              np.vstack(body_text_embeddings),
+                                              "cosine").reshape(-1, 1)
+
+         body_text_distances = [(idx, 1 - dist[0]) for idx, dist in enumerate(qbody)]
+
+         results = sorted(body_text_distances, key=lambda x: x[1], reverse=True)
+
+         if results:
+             all_text_distances.append((article_idx, results[0][0], results[0][1]))
+
+     results = sorted(all_text_distances, key=lambda x: x[2], reverse=True)
+
+     return results
+
+
+ def fetch_stage3(results, query, embeddings, comprehension_model, spacy_nlp):
+     """
+     Run the question-answering model over the top 20 retrieved paragraphs
+     and return the answers sorted by score, highest first.
+     """
+
+     answers = []
+
+     for top_text in results[0:20]:
+         article_idx = top_text[0]
+         body_text_idx = top_text[1]
+
+         query_ = {"context": embeddings[article_idx][2][body_text_idx][0], "question": query}
+         pred = comprehension_model(query_, topk=1)
+
+         # If there is any answer
+         if pred["answer"] and round(pred["score"], 4) > 0:
+             # Take the full sentence containing the answer span from the paragraph
+             sent = get_full_sentence(spacy_nlp, query_['context'], pred["start"], pred["end"])
+             answers.append((article_idx, round(pred["score"], 4), sent))
+
+     results = sorted(answers, key=lambda x: x[1], reverse=True)
+
+     return results
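
Both retrieval stages rank candidates by the same score: scipy's cdist returns a cosine distance, and 1 - distance is used as the similarity. A tiny self-contained illustration of that pattern with toy vectors (not the real embeddings):

import numpy as np
import scipy.spatial

query_embedding = np.array([1.0, 0.0, 0.0])
sentence_embeddings = np.vstack([[1.0, 0.0, 0.0],    # same direction -> similarity ~1.0
                                 [0.0, 1.0, 0.0]])   # orthogonal     -> similarity ~0.0

cdists = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine").reshape(-1, 1)
similarities = [(idx, 1 - d[0]) for idx, d in enumerate(cdists)]
best = sorted(similarities, key=lambda x: x[1], reverse=True)[0]
print(best)   # roughly (0, 1.0): the best-matching sentence and its score, as in fetch_stage1/fetch_stage2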