awinml committed
Commit 8cd1f1e
1 Parent(s): 6e53191

Upload 2 files

Files changed (2)
  1. app.py +196 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,196 @@
+ import pandas as pd
+ from tqdm import tqdm
+ import pinecone
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from transformers import (
+     pipeline,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     AutoModelForSeq2SeqLM,
+ )
+ import streamlit as st
+ import openai
+
+
+ # Initialize models from HuggingFace
+
+
+ @st.experimental_singleton
+ def get_t5_model():
+     return pipeline("summarization", model="t5-small", tokenizer="t5-small")
+
+
+ @st.experimental_singleton
+ def get_flan_t5_model():
+     return pipeline(
+         "summarization", model="google/flan-t5-small", tokenizer="google/flan-t5-small"
+     )
+
+
+ @st.experimental_singleton
+ def get_mpnet_embedding_model():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = SentenceTransformer(
+         "sentence-transformers/all-mpnet-base-v2", device=device
+     )
+     model.max_seq_length = 512
+     return model
+
+
+ @st.experimental_singleton
+ def get_sgpt_embedding_model():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = SentenceTransformer(
+         "Muennighoff/SGPT-125M-weightedmean-nli-bitfit", device=device
+     )
+     model.max_seq_length = 512
+     return model
+
+
+ @st.experimental_memo
+ def save_key(api_key):
+     return api_key
+
+
+ def query_pinecone(query, top_k, model, index):
+     # generate embeddings for the query
+     xq = model.encode([query]).tolist()
+     # search pinecone index for context passage with the answer
+     xc = index.query(xq, top_k=top_k, include_metadata=True)
+     return xc
+
+
+ def format_query(query_results):
+     # extract passage_text from Pinecone search result
+     context = [result["metadata"]["Text"] for result in query_results["matches"]]
+     return context
+
+
+ def gpt3_summary(text):
+     response = openai.Completion.create(
+         model="text-davinci-003",
+         prompt=text + "\n\nTl;dr",
+         temperature=0.1,
+         max_tokens=512,
+         top_p=1.0,
+         frequency_penalty=0.0,
+         presence_penalty=1,
+     )
+     return response.choices[0].text
+
+
+ def gpt3_qa(query, answer):
+     response = openai.Completion.create(
+         model="text-davinci-003",
+         prompt="Q: " + query + "\nA: " + answer,
+         temperature=0,
+         max_tokens=512,
+         top_p=1,
+         frequency_penalty=0.0,
+         presence_penalty=0.0,
+         stop=["\n"],
+     )
+     return response.choices[0].text
+
+
+ st.title("Abstractive Question Answering - APPL")
+
+ query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
+
+ num_results = int(st.number_input("Number of Results to query", 1, 5, value=2))
+
+
+ # Choose encoder model
+
+ encoder_models_choice = ["MPNET", "SGPT"]
+
+ encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
+
+
+ # Choose decoder model
+
+ decoder_models_choice = ["GPT3 (QA_davinci)", "GPT3 (text_davinci)", "T5", "FLAN-T5"]
+
+ decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
+
+
+ if encoder_model == "MPNET":
+     # Connect to pinecone environment
+     pinecone.init(
+         api_key="<PINECONE_API_KEY>", environment="us-east1-gcp"  # key redacted; supply your own
+     )
+     pinecone_index_name = "week2-all-mpnet-base"
+     pinecone_index = pinecone.Index(pinecone_index_name)
+     retriever_model = get_mpnet_embedding_model()
+
+ elif encoder_model == "SGPT":
+     # Connect to pinecone environment
+     pinecone.init(
+         api_key="<PINECONE_API_KEY>", environment="us-east1-gcp"  # key redacted; supply your own
+     )
+     pinecone_index_name = "week2-sgpt-125m"
+     pinecone_index = pinecone.Index(pinecone_index_name)
+     retriever_model = get_sgpt_embedding_model()
+
+
+ query_results = query_pinecone(query_text, num_results, retriever_model, pinecone_index)
+
+ context_list = format_query(query_results)
+
+
+ st.subheader("Answer:")
+
+
+ if decoder_model == "GPT3 (text_davinci)":
+     openai_key = st.text_input(
+         "Enter OpenAI key",
+         value="",  # key redacted; paste your own OpenAI API key
+         type="password",
+     )
+     api_key = save_key(openai_key)
+     openai.api_key = api_key
+     output_text = []
+     for context_text in context_list:
+         output_text.append(gpt3_summary(context_text))
+     generated_text = " ".join(output_text)
+     st.write(gpt3_summary(generated_text))
+
+ elif decoder_model == "GPT3 (QA_davinci)":  # must match the selectbox option, else unreachable
+     openai_key = st.text_input(
+         "Enter OpenAI key",
+         value="",  # key redacted; paste your own OpenAI API key
+         type="password",
+     )
+     api_key = save_key(openai_key)
+     openai.api_key = api_key
+     output_text = []
+     for context_text in context_list:
+         output_text.append(gpt3_qa(query_text, context_text))
+     generated_text = " ".join(output_text)
+     st.write(gpt3_qa(query_text, generated_text))
+
+ elif decoder_model == "T5":
+     t5_pipeline = get_t5_model()
+     output_text = []
+     for context_text in context_list:
+         output_text.append(t5_pipeline(context_text)[0]["summary_text"])
+     generated_text = " ".join(output_text)
+     st.write(t5_pipeline(generated_text)[0]["summary_text"])
+
+ elif decoder_model == "FLAN-T5":
+     flan_t5_pipeline = get_flan_t5_model()
+     output_text = []
+     for context_text in context_list:
+         output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
+     generated_text = " ".join(output_text)
+     st.write(flan_t5_pipeline(generated_text)[0]["summary_text"])
+
+ show_retrieved_text = st.checkbox("Show Retrieved Text", value=False)
+
+ if show_retrieved_text:
+
+     st.subheader("Retrieved Text:")
+
+     for context_text in context_list:
+         st.markdown(f"- {context_text}")
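
The retrieval path can be sanity-checked outside Streamlit. Below is a minimal sketch that reuses the pinecone-client 2.x calls from app.py; the PINECONE_API_KEY environment variable and a populated week2-all-mpnet-base index are assumptions, not part of this commit.

    import os

    import pinecone
    from sentence_transformers import SentenceTransformer

    # Same environment and index name as the MPNET branch in app.py;
    # the PINECONE_API_KEY env var is an assumption of this sketch.
    pinecone.init(api_key=os.environ["PINECONE_API_KEY"], environment="us-east1-gcp")
    index = pinecone.Index("week2-all-mpnet-base")

    # Same encoder settings as get_mpnet_embedding_model(), without Streamlit caching
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    model.max_seq_length = 512

    # Mirrors query_pinecone() and format_query()
    xq = model.encode(["Who is the CEO of Apple?"]).tolist()
    res = index.query(xq, top_k=2, include_metadata=True)
    for match in res["matches"]:
        print(round(match["score"], 3), match["metadata"]["Text"][:100])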
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ pandas
+ tqdm
+ pinecone-client
+ torch
+ sentence_transformers
+ transformers
+ streamlit
+ openai
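
With these dependencies installed (`pip install -r requirements.txt`), the app launches with `streamlit run app.py`. Note the packages are unpinned: app.py relies on st.experimental_singleton, st.experimental_memo, pinecone.init, and openai.Completion, so it assumes release lines of streamlit, pinecone-client, and openai (pre-1.0) that still expose those APIs.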