Update app.py
app.py
CHANGED
@@ -12,6 +12,7 @@ from haystack.document_stores import FAISSDocumentStore
 from haystack.nodes import EmbeddingRetriever
 from haystack.schema import Document
 from huggingface_hub import login, HfApi, hf_hub_download, InferenceClient
+import openai
 
 # Get HF token
 hf_token = os.environ["HF_TOKEN"]
@@ -19,18 +20,30 @@ login(token=hf_token, add_to_git_credential=True)
 
 # Get openai API key
 openai_key = os.environ["OPENAI_API_KEY"]
+openai.api_key = openai_key
+
+# Get pinecone API key
+pinecone_key = os.environ["PINECONE_API_KEY"]
 
 
-template = PromptTemplate(
-prompt="""
-Answer the given question using the following documents. \
+# template = PromptTemplate(
+#     prompt="""
+#     Answer the given question using the following documents. \
+#     Formulate your answer in the style of an academic report. \
+#     Provide example quotes and citations using extracted text from the documents. \
+#     Use facts and numbers from the documents in your answer. \
+#     Reference information used from documents at the end of each applicable sentence (ex: [source: document_name]), where 'document_name' is the text provided at the start of each document (demarcated by '- &&&' and '&&&:')'. \
+#     If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer. \
+#     Context: {' - '.join(['&&& '+d.meta['document']+' ref. '+str(d.meta['ref_id'])+' &&&: '+d.content for d in documents])}; Question: {query}; Answer:""",
+# )
+
+prompt_template="Answer the given question using the following documents. \
 Formulate your answer in the style of an academic report. \
 Provide example quotes and citations using extracted text from the documents. \
 Use facts and numbers from the documents in your answer. \
-Reference information used from documents at the end of each applicable sentence (ex: [source: document_name]), where 'document_name' is the text provided at the start of each document (demarcated by '- &&&' and '&&&:')'. \
-If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer. \
-Context: {' - '.join(['&&& '+d.meta['document']+' ref. '+str(d.meta['ref_id'])+' &&&: '+d.content for d in documents])}; Question: {query}; Answer:""",
-)
+ALWAYS include references for information used from documents at the end of each applicable sentence using the format: '[ref. #]', where '[ref. #]' is included in the text provided at the start of each document (demarcated by the pattern '- &&& [ref. #] document_name &&&:')'. \
+Do not include page numbers in the references. \
+If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer."
 
 # Create a list of options for the dropdown
 model_options = ['chatGPT','Llama2']
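For orientation on this hunk: the old PromptTemplate is kept as a comment, and prompt_template is now a plain module-level string that get_prompt (further down) concatenates with a serialized document context and the user question. A minimal sketch of the assembled prompt, using a hypothetical document and an abridged stub of the template:

    from haystack.schema import Document

    prompt_template = "Answer the given question using the following documents."  # abridged stub

    # Hypothetical retrieved passage
    doc = Document(
        "Droughts have reduced maize yields since 2015.",
        meta={"ref_id": 1, "document": "Malawi NDC 2021"},
    )

    # Each document is serialized as '&&& [ref. N] document_name &&&: content',
    # which is what lets the model cite '[ref. N]' in its answer
    context = " - ".join(
        "&&& [ref. " + str(d.meta["ref_id"]) + "] " + d.meta["document"] + " &&&: " + d.content
        for d in [doc]
    )
    prompt = prompt_template + "; Context: " + context + "; Question: What affects maize yields?; Answer:"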
@@ -39,7 +52,7 @@ model_options = ['chatGPT','Llama2']
 country_options = ['All Countries','Angola','Botswana','Lesotho','Kenya','Malawi','Mozambique','Namibia','Rwanda','South Africa','Zambia','Zimbabwe']
 
 # Create a list of options for the dropdown
-…
+vulnerability_options = ['All Categories','Agricultural communities', 'Children', 'Coastal communities', 'Ethnic, racial or other minorities', 'Fishery communities', 'Informal sector workers', 'Members of indigenous and local communities', 'Migrants and displaced persons', 'Older persons', 'Persons living in poverty', 'Persons with disabilities', 'Persons with pre-existing health conditions', 'Residents of drought-prone regions', 'Rural populations', 'Sexual minorities (LGBTQI+)', 'Urban populations', 'Women and other genders','Other']
 
 # List of examples
 examples = [
@@ -48,40 +61,74 @@ examples = [
     "In addition to gender, children, and youth, is there any mention of other groups facing disproportional impacts from climate change due to their geographic location, socio-economic status, age, gender, health, and occupation?"
 ]
 
-def get_docs(input_query, country = None):
-    '''
-    Construct a hacky query to focus the retriever on the target country (see notes below)
-    We take the top 150 k because we want to make sure we have 10 pertaining to the selected country
-    '''
-    if country == 'All Countries':
-        query = input_query
-    else:
-        query = "For the country of "+country+", "+input_query
-    # Retrieve top k documents
-    docs = retriever.retrieve(query=query,top_k = 150)
-    # Break out the key fields and convert to pandas for filtering
-    docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
-    df_docs = pd.DataFrame(docs)
-    if country != 'All Countries':
-        df_docs = df_docs.query('country in @country')
-    # Take the top 10
-    df_docs = df_docs.head(10)
-    # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
-    df_docs = df_docs.reset_index()
-    df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
-    # Convert back to Document format
-    ls_dict = []
-    # Iterate over df and add relevant fields to the dict object
-    for index, row in df_docs.iterrows():
-        # Create a Document object for each row
-        doc = Document(
-            row['content'],
-            meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'score': row['score']}
-        )
-
-        # Append the Document object to the documents list
-        ls_dict.append(doc)
-    return(ls_dict)
+# def get_docs(input_query, country = None):
+#     '''
+#     Construct a hacky query to focus the retriever on the target country (see notes below)
+#     We take the top 150 k because we want to make sure we have 10 pertaining to the selected country
+#     '''
+#     if country == 'All Countries':
+#         query = input_query
+#     else:
+#         query = "For the country of "+country+", "+input_query
+#     # Retrieve top k documents
+#     docs = retriever.retrieve(query=query,top_k = 150)
+#     # Break out the key fields and convert to pandas for filtering
+#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
+#     df_docs = pd.DataFrame(docs)
+#     if country != 'All Countries':
+#         df_docs = df_docs.query('country in @country')
+#     # Take the top 10
+#     df_docs = df_docs.head(10)
+#     # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
+#     df_docs = df_docs.reset_index()
+#     df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
+#     # Convert back to Document format
+#     ls_dict = []
+#     # Iterate over df and add relevant fields to the dict object
+#     for index, row in df_docs.iterrows():
+#         # Create a Document object for each row
+#         doc = Document(
+#             row['content'],
+#             meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'score': row['score']}
+#         )
+
+#         # Append the Document object to the documents list
+#         ls_dict.append(doc)
+#     return(ls_dict)
+
+def get_docs(input_query, country = [], vulnerability_cat = []):
+    if not country:
+        country = "All Countries"
+    if not vulnerability_cat:
+        if country == "All Countries":
+            filters = None
+        else:
+            filters = {'country': {'$in': country}}
+    else:
+        if country == "All Countries":
+            filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
+        else:
+            filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
+    docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
+    # Break out the key fields and convert to pandas for filtering
+    docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
+    df_docs = pd.DataFrame(docs)
+    # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
+    df_docs = df_docs.reset_index()
+    df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
+    # Convert back to Document format
+    ls_dict = []
+    # Iterate over df and add relevant fields to the dict object
+    for index, row in df_docs.iterrows():
+        # Create a Document object for each row
+        doc = Document(
+            row['content'],
+            meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'vulnerability_cat': row['vulnerability_cat'], 'score': row['score']}
+        )
+
+        # Append the Document object to the documents list
+        ls_dict.append(doc)
+    return ls_dict
 
 def get_refs(docs, res):
     '''
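The rewritten get_docs pushes country and vulnerability filtering down into the document store instead of over-fetching 150 candidates and filtering in pandas. A sketch of the filter shapes it builds, assuming Pinecone-style metadata operators as surfaced by Haystack; note that '$in' matches against a list of allowed values, so a single string from a selectbox may need wrapping (e.g. {'$in': [country]}):

    # Country only
    filters = {"country": {"$in": ["Kenya", "Malawi"]}}

    # Country and vulnerability category combined (implicit AND across keys)
    filters = {
        "country": {"$in": ["Kenya"]},
        "vulnerability_cat": {"$in": ["Children", "Rural populations"]},
    }

    # Applied at retrieval time, so only matching passages are scored
    docs = retriever.retrieve(query="climate impacts on children", filters=filters, top_k=10)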
@@ -107,26 +154,22 @@ def get_refs(docs, res):
     return result_str
 
 # define a special function for putting the prompt together (as we can't use haystack)
-def get_prompt_llama2(docs, query):
-    base_prompt="Answer the given question using the following documents. \
-    Formulate your answer in the style of an academic report. \
-    Provide example quotes and citations using extracted text from the documents. \
-    Use facts and numbers from the documents in your answer. \
-    ALWAYS include references for information used from documents at the end of each applicable sentence using the format: '[ref. #]', where '[ref. #]' is included in the text provided at the start of each document (demarcated by the pattern '- &&& [ref. #] document_name &&&:')'. \
-    Do not include page numbers in the references. \
-    If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer."
+def get_prompt(docs, query):
+    base_prompt=prompt_template
     # Add the meta data for references
     context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
     prompt = base_prompt+"; Context: "+context+"; Question: "+query+"; Answer:"
     return(prompt)
 
 def run_query(input_text, country, model_sel):
-    docs = get_docs(input_text, country)
+    # docs = get_docs(input_text, country)
+    docs = get_docs(input_text, country=country, vulnerability_cat=vulnerabilities_cat)
     # st.write('Selected country: ', country) # Debugging country
     if model_sel == "chatGPT":
-        res = pipe.run(query=input_text, documents=docs)
-        output = res["results"][0]
+        # res = pipe.run(query=input_text, documents=docs)
+        res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, query=input_text)}])
+        output = res["choices"][0]["message"]["content"]
         references = get_refs(docs, output)
     else:
         res = client.text_generation(get_prompt_llama2(docs, query=input_text), max_new_tokens=4000, temperature=0.01, model=model)
         output = res
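With the switch from pipe.run() to calling the OpenAI SDK directly (the pre-1.0 openai package, as imported at the top of the file), the generated text is returned under 'choices' rather than the Haystack pipeline's 'results'. A minimal sketch of the response handling:

    import openai

    res = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    # The answer text lives under 'choices', not 'results'
    output = res["choices"][0]["message"]["content"]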
@@ -137,17 +180,30 @@ def run_query(input_text, country, model_sel):
     st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
     st.markdown(references, unsafe_allow_html=True)
 
-# Setup retriever, pulling from local faiss datastore
+# # Setup retriever, pulling from local faiss datastore
+# retriever = EmbeddingRetriever(
+#     document_store=FAISSDocumentStore.load(
+#         index_path="./cpv_full_southern_africa_kenya.faiss",
+#         config_path="./cpv_full_southern_africa_kenya.json",
+#     ),
+#     embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
+#     model_format="sentence_transformers",
+#     progress_bar=False,
+# )
+
+# Setup retriever, pulling from pinecone
+doc_file_name="cpv_full_southern_africa"
+document_store = PineconeDocumentStore(api_key=pinecone_key,
+                                       environment="asia-southeast1-gcp-free",
+                                       index=doc_file_name)
+
 retriever = EmbeddingRetriever(
-    document_store=FAISSDocumentStore.load(
-        index_path="./cpv_full_southern_africa_kenya.faiss",
-        config_path="./cpv_full_southern_africa_kenya.json",
-    ),
+    document_store=document_store,
     embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
-    model_format="sentence_transformers",
-    progress_bar=False,
+    model_format="sentence_transformers"
 )
 
+
 with st.sidebar:
     # Dropdown selectbox
     country = st.sidebar.selectbox('Select a country:', country_options)
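Note that no import for PineconeDocumentStore appears in this diff; in farm-haystack 1.x it is exported from haystack.document_stores, so presumably it is imported elsewhere in the file or still needs to be added. A sketch of the assumed import and store setup:

    # Assumed import (not shown in this diff); farm-haystack 1.x
    from haystack.document_stores import PineconeDocumentStore

    document_store = PineconeDocumentStore(
        api_key=pinecone_key,                    # from the PINECONE_API_KEY env var
        environment="asia-southeast1-gcp-free",
        index="cpv_full_southern_africa",
    )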
@@ -157,7 +213,7 @@ with st.sidebar:
     * *Then be sure to mention the country names of interest in your query*
     """
     )
-    vulnerabilities_cat = st.sidebar.selectbox('Select a vulnerabilities category:', …
+    vulnerabilities_cat = st.sidebar.selectbox('Select a vulnerabilities category:', vulnerability_options)
     # choice = st.sidebar.radio(label = 'Select the Document',
     #     help = 'You can upload the document \
     #     or else you can try a example document',
@@ -189,12 +245,12 @@ model_sel = "chatGPT"
 #----Model Select logic-------
 if model_sel == "chatGPT":
     model_name = "gpt-3.5-turbo"
-    # Initialize the PromptNode
-    pn = PromptNode(model_name_or_path=model_name, default_prompt_template=template, api_key=openai_key, max_length=2000, model_kwargs={"generation_kwargs": {"do_sample": False, "temperature": 0}})
+    # # Initialize the PromptNode
+    # pn = PromptNode(model_name_or_path=model_name, default_prompt_template=template, api_key=openai_key, max_length=2000, model_kwargs={"generation_kwargs": {"do_sample": False, "temperature": 0}})
 
-    # Initialize the pipeline
-    pipe = Pipeline()
-    pipe.add_node(component=pn, name="prompt_node", inputs=["Query"])
+    # # Initialize the pipeline
+    # pipe = Pipeline()
+    # pipe.add_node(component=pn, name="prompt_node", inputs=["Query"])
 else:
     # Currently disabled
     model = "meta-llama/Llama-2-70b-chat-hf"