mtyrrell commited on
Commit
430b690
1 Parent(s): 7a50275

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -67
app.py CHANGED
@@ -12,6 +12,7 @@ from haystack.document_stores import FAISSDocumentStore
12
  from haystack.nodes import EmbeddingRetriever
13
  from haystack.schema import Document
14
  from huggingface_hub import login, HfApi, hf_hub_download, InferenceClient
 
15
 
16
  # Get HF token
17
  hf_token = os.environ["HF_TOKEN"]
@@ -19,18 +20,30 @@ login(token=hf_token, add_to_git_credential=True)
19
 
20
  # Get openai API key
21
  openai_key = os.environ["OPENAI_API_KEY"]
 
 
 
 
22
 
23
 
24
- template = PromptTemplate(
25
- prompt="""
26
- Answer the given question using the following documents. \
 
 
 
 
 
 
 
 
 
27
  Formulate your answer in the style of an academic report. \
28
  Provide example quotes and citations using extracted text from the documents. \
29
  Use facts and numbers from the documents in your answer. \
30
- Reference information used from documents at the end of each applicable sentence (ex: [source: document_name]), where 'document_name' is the text provided at the start of each document (demarcated by '- &&&' and '&&&:')'. \
31
- If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer. \
32
- Context: {' - '.join(['&&& '+d.meta['document']+' ref. '+str(d.meta['ref_id'])+' &&&: '+d.content for d in documents])}; Question: {query}; Answer:""",
33
- )
34
 
35
  # Create a list of options for the dropdown
36
  model_options = ['chatGPT','Llama2']
@@ -39,7 +52,7 @@ model_options = ['chatGPT','Llama2']
39
  country_options = ['All Countries','Angola','Botswana','Lesotho','Kenya','Malawi','Mozambique','Namibia','Rwanda','South Africa','Zambia','Zimbabwe']
40
 
41
  # Create a list of options for the dropdown
42
- class_options = ['All Categories','Agricultural communities', 'Children', 'Coastal communities', 'Ethnic, racial or other minorities', 'Fishery communities', 'Informal sector workers', 'Members of indigenous and local communities', 'Migrants and displaced persons', 'Older persons', 'Persons living in poverty', 'Persons with disabilities', 'Persons with pre-existing health conditions', 'Residents of drought-prone regions', 'Rural populations', 'Sexual minorities (LGBTQI+)', 'Urban populations', 'Women and other genders','Other']
43
 
44
  # List of examples
45
  examples = [
@@ -48,40 +61,74 @@ examples = [
48
  "In addition to gender, children, and youth, is there any mention of other groups facing disproportional impacts from climate change due to their geographic location, socio-economic status, age, gender, health, and occupation?"
49
  ]
50
 
51
- def get_docs(input_query, country = None):
52
- '''
53
- Construct a hacky query to focus the retriever on the target country (see notes below)
54
- We take the top 150 k because we want to make sure we have 10 pertaining to the selected country
55
- '''
56
- if country == 'All Countries':
57
- query = input_query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
- query = "For the country of "+country+", "+input_query
60
- # Retrieve top k documents
61
- docs = retriever.retrieve(query=query,top_k = 150)
62
- # Break out the key fields and convert to pandas for filtering
63
- docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
64
- df_docs = pd.DataFrame(docs)
65
- if country != 'All Countries':
66
- df_docs = df_docs.query('country in @country')
67
- # Take the top 10
68
- df_docs = df_docs.head(10)
69
- # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
70
- df_docs = df_docs.reset_index()
71
- df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
72
- # Convert back to Document format
73
- ls_dict = []
74
- # Iterate over df and add relevant fields to the dict object
75
- for index, row in df_docs.iterrows():
76
- # Create a Document object for each row
77
- doc = Document(
78
- row['content'],
79
- meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'score': row['score']}
80
- )
81
-
82
- # Append the Document object to the documents list
83
- ls_dict.append(doc)
84
- return(ls_dict)
85
 
86
  def get_refs(docs, res):
87
  '''
@@ -107,26 +154,22 @@ def get_refs(docs, res):
107
  return result_str
108
 
109
  # define a special function for putting the prompt together (as we can't use haystack)
110
- def get_prompt_llama2(docs, query):
111
- base_prompt="Answer the given question using the following documents. \
112
- Formulate your answer in the style of an academic report. \
113
- Provide example quotes and citations using extracted text from the documents. \
114
- Use facts and numbers from the documents in your answer. \
115
- ALWAYS include references for information used from documents at the end of each applicable sentence using the format: '[ref. #]', where '[ref. #]' is included in the text provided at the start of each document (demarcated by the pattern '- &&& [ref. #] document_name &&&:')'. \
116
- Do not include page numbers in the references. \
117
- If no relevant information to answer the question is present in the documents, just say you don't have enough information to answer."
118
  # Add the meta data for references
119
- context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document_name']+' &&&: '+d.content for d in docs])
120
  prompt = base_prompt+"; Context: "+context+"; Question: "+query+"; Answer:"
121
  return(prompt)
122
 
123
  def run_query(input_text, country, model_sel):
124
- docs = get_docs(input_text, country)
 
125
  # st.write('Selected country: ', country) # Debugging country
126
  if model_sel == "chatGPT":
127
- res = pipe.run(query=input_text, documents=docs)
 
128
  output = res["results"][0]
129
- references = get_refs(docs, res["results"][0])
130
  else:
131
  res = client.text_generation(get_prompt_llama2(docs, query=input_text), max_new_tokens=4000, temperature=0.01, model=model)
132
  output = res
@@ -137,17 +180,30 @@ def run_query(input_text, country, model_sel):
137
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
138
  st.markdown(references, unsafe_allow_html=True)
139
 
140
- # Setup retriever, pulling from local faiss datastore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  retriever = EmbeddingRetriever(
142
- document_store=FAISSDocumentStore.load(
143
- index_path="./cpv_full_southern_africa_kenya.faiss",
144
- config_path="./cpv_full_southern_africa_kenya.json",
145
- ),
146
- embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
147
- model_format="sentence_transformers",
148
- progress_bar=False,
149
  )
150
 
 
151
  with st.sidebar:
152
  # Dropdown selectbox
153
  country = st.sidebar.selectbox('Select a country:', country_options)
@@ -157,7 +213,7 @@ with st.sidebar:
157
  * *Then be sure to mention the country names of interest in your query*
158
  """
159
  )
160
- vulnerabilities_cat = st.sidebar.selectbox('Select a vulnerabilities category:', class_options)
161
  # choice = st.sidebar.radio(label = 'Select the Document',
162
  # help = 'You can upload the document \
163
  # or else you can try a example document',
@@ -189,12 +245,12 @@ model_sel = "chatGPT"
189
  #----Model Select logic-------
190
  if model_sel == "chatGPT":
191
  model_name = "gpt-3.5-turbo"
192
- # Initialize the PromptNode
193
- pn = PromptNode(model_name_or_path=model_name, default_prompt_template=template, api_key=openai_key, max_length=2000, model_kwargs={"generation_kwargs": {"do_sample": False, "temperature": 0}})
194
 
195
- # Initialize the pipeline
196
- pipe = Pipeline()
197
- pipe.add_node(component=pn, name="prompt_node", inputs=["Query"])
198
  else:
199
  # Currently disabled
200
  model = "meta-llama/Llama-2-70b-chat-hf"
 
12
  from haystack.nodes import EmbeddingRetriever
13
  from haystack.schema import Document
14
  from huggingface_hub import login, HfApi, hf_hub_download, InferenceClient
15
+ import openai
16
 
17
  # Get HF token
18
  hf_token = os.environ["HF_TOKEN"]
 
20
 
21
# Get OpenAI API key and configure the openai client
openai_key = os.environ["OPENAI_API_KEY"]
# was: openai.api_key = openai_api_key — 'openai_api_key' is undefined (NameError);
# the variable read from the environment above is 'openai_key'.
openai.api_key = openai_key

# Get Pinecone API key (comment previously said "openai" — copy/paste error)
pinecone_key = os.environ["PINECONE_API_KEY"]
27
 
28
 
29
# Instruction preamble shared by both model paths. This replaces the old
# Haystack PromptTemplate (deleted dead code): the context and question are
# appended manually in get_prompt(), because the Llama2 path cannot go
# through Haystack's PromptNode.
prompt_template = (
    "Answer the given question using the following documents. "
    "Formulate your answer in the style of an academic report. "
    "Provide example quotes and citations using extracted text from the documents. "
    "Use facts and numbers from the documents in your answer. "
    "ALWAYS include references for information used from documents at the end of "
    "each applicable sentence using the format: '[ref. #]', where '[ref. #]' is "
    "included in the text provided at the start of each document (demarcated by "
    "the pattern '- &&& [ref. #] document_name &&&:')'. "
    "Do not include page numbers in the references. "
    "If no relevant information to answer the question is present in the "
    "documents, just say you don't have enough information to answer."
)
 
47
 
48
  # Create a list of options for the dropdown
49
  model_options = ['chatGPT','Llama2']
 
52
  country_options = ['All Countries','Angola','Botswana','Lesotho','Kenya','Malawi','Mozambique','Namibia','Rwanda','South Africa','Zambia','Zimbabwe']
53
 
54
  # Create a list of options for the dropdown
55
+ vulnerability_options = ['All Categories','Agricultural communities', 'Children', 'Coastal communities', 'Ethnic, racial or other minorities', 'Fishery communities', 'Informal sector workers', 'Members of indigenous and local communities', 'Migrants and displaced persons', 'Older persons', 'Persons living in poverty', 'Persons with disabilities', 'Persons with pre-existing health conditions', 'Residents of drought-prone regions', 'Rural populations', 'Sexual minorities (LGBTQI+)', 'Urban populations', 'Women and other genders','Other']
56
 
57
  # List of examples
58
  examples = [
 
61
  "In addition to gender, children, and youth, is there any mention of other groups facing disproportional impacts from climate change due to their geographic location, socio-economic status, age, gender, health, and occupation?"
62
  ]
63
 
64
+ # def get_docs(input_query, country = None):
65
+ # '''
66
+ # Construct a hacky query to focus the retriever on the target country (see notes below)
67
+ # We take the top 150 k because we want to make sure we have 10 pertaining to the selected country
68
+ # '''
69
+ # if country == 'All Countries':
70
+ # query = input_query
71
+ # else:
72
+ # query = "For the country of "+country+", "+input_query
73
+ # # Retrieve top k documents
74
+ # docs = retriever.retrieve(query=query,top_k = 150)
75
+ # # Break out the key fields and convert to pandas for filtering
76
+ # docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
77
+ # df_docs = pd.DataFrame(docs)
78
+ # if country != 'All Countries':
79
+ # df_docs = df_docs.query('country in @country')
80
+ # # Take the top 10
81
+ # df_docs = df_docs.head(10)
82
+ # # Get ourselves an index setup from which to base the source reference number from (in the prompt and matching afterwards)
83
+ # df_docs = df_docs.reset_index()
84
+ # df_docs['ref_id'] = df_docs.index + 1 # start the index at 1
85
+ # # Convert back to Document format
86
+ # ls_dict = []
87
+ # # Iterate over df and add relevant fields to the dict object
88
+ # for index, row in df_docs.iterrows():
89
+ # # Create a Document object for each row
90
+ # doc = Document(
91
+ # row['content'],
92
+ # meta={'country': row['country'],'document': row['document'], 'page': row['page'], 'file_name': row['file_name'], 'ref_id': row['ref_id'], 'score': row['score']}
93
+ # )
94
+
95
+ # # Append the Document object to the documents list
96
+ # ls_dict.append(doc)
97
+ # return(ls_dict)
98
+
99
def get_docs(input_query, country=None, vulnerability_cat=None):
    """Retrieve the top-10 documents for a query, optionally filtered by
    country and vulnerability category via Pinecone metadata filters.

    Returns a list of haystack Documents with a 1-based 'ref_id' added to the
    metadata, which drives the '[ref. #]' citations in the prompt and the
    matching in get_refs().
    """
    # Mutable default arguments ([]) replaced with None; empty/missing
    # selections mean "no filter" either way because of the truthiness checks.
    if not country:
        country = "All Countries"
    # Pinecone's '$in' operator expects a list; the sidebar selectbox yields a
    # scalar string, so wrap scalars before filtering.
    country_list = country if isinstance(country, list) else [country]
    if not vulnerability_cat:
        filters = None if country == "All Countries" else {'country': {'$in': country_list}}
    else:
        vuln_list = vulnerability_cat if isinstance(vulnerability_cat, list) else [vulnerability_cat]
        if country == "All Countries":
            filters = {'vulnerability_cat': {'$in': vuln_list}}
        else:
            filters = {'country': {'$in': country_list}, 'vulnerability_cat': {'$in': vuln_list}}
    # was: retriever.retrieve(query=query, ...) — 'query' is undefined in this
    # scope (NameError); the function parameter is 'input_query'.
    docs = retriever.retrieve(query=input_query, filters=filters, top_k=10)
    # Break out the key fields and convert to pandas for post-processing
    docs = [{**x.meta, "score": x.score, "content": x.content} for x in docs]
    df_docs = pd.DataFrame(docs)
    # Reset the index so the reference numbers are stable, then start them at 1
    df_docs = df_docs.reset_index()
    df_docs['ref_id'] = df_docs.index + 1
    # Convert back to Document format, carrying the metadata forward
    ls_dict = []
    for _, row in df_docs.iterrows():
        doc = Document(
            row['content'],
            meta={'country': row['country'], 'document': row['document'],
                  'page': row['page'], 'file_name': row['file_name'],
                  'ref_id': row['ref_id'],
                  'vulnerability_cat': row['vulnerability_cat'],
                  'score': row['score']}
        )
        ls_dict.append(doc)
    return ls_dict
132
 
133
  def get_refs(docs, res):
134
  '''
 
154
  return result_str
155
 
156
  # define a special function for putting the prompt together (as we can't use haystack)
157
def get_prompt(docs, query):
    """Build the full LLM prompt: shared instructions, the referenced
    document context, and the user's question."""
    # Prefix each document with its '[ref. #]' marker so the model can cite it
    doc_blocks = [
        '&&& [ref. ' + str(d.meta['ref_id']) + '] ' + d.meta['document'] + ' &&&: ' + d.content
        for d in docs
    ]
    context = ' - '.join(doc_blocks)
    return prompt_template + "; Context: " + context + "; Question: " + query + "; Answer:"
163
 
164
def run_query(input_text, country, model_sel):
    """Retrieve filtered documents for the query and generate an answer with
    the selected model, writing the answer and references to the Streamlit UI.
    """
    # was: get_docs(query, ...) — 'query' is undefined here (NameError); the
    # parameter is 'input_text'.
    # was: vulnerability_cat=vulnerability_options — that passed the ENTIRE
    # options list (including 'All Categories') rather than the sidebar
    # selection. NOTE(review): 'vulnerabilities_cat' is the module-level
    # selectbox value — confirm it is defined before the first query runs.
    docs = get_docs(input_text, country=country, vulnerability_cat=vulnerabilities_cat)
    # st.write('Selected country: ', country) # Debugging country
    if model_sel == "chatGPT":
        res = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": get_prompt(docs, query=input_text)}],
        )
        # was: res["results"][0] — ChatCompletion responses have no 'results'
        # key; the generated text lives in choices[0].message.content.
        output = res["choices"][0]["message"]["content"]
        references = get_refs(docs, output)
    else:
        # was: get_prompt_llama2(...) — that function was renamed to
        # get_prompt in this commit, so the old name no longer exists.
        res = client.text_generation(get_prompt(docs, query=input_text), max_new_tokens=4000, temperature=0.01, model=model)
        output = res
 
180
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
181
  st.markdown(references, unsafe_allow_html=True)
182
 
183
# Setup retriever, pulling from Pinecone (replaces the old local FAISS
# datastore; the dead commented-out FAISS setup has been removed).
# NOTE(review): PineconeDocumentStore is not imported in the visible part of
# this file — ensure the top-of-file imports include
# `from haystack.document_stores import PineconeDocumentStore`.
doc_file_name = "cpv_full_southern_africa"
document_store = PineconeDocumentStore(
    api_key=pinecone_key,
    environment="asia-southeast1-gcp-free",
    index=doc_file_name,
)

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
)
205
 
206
+
207
  with st.sidebar:
208
  # Dropdown selectbox
209
  country = st.sidebar.selectbox('Select a country:', country_options)
 
213
  * *Then be sure to mention the country names of interest in your query*
214
  """
215
  )
216
+ vulnerabilities_cat = st.sidebar.selectbox('Select a vulnerabilities category:', vulnerability_options)
217
  # choice = st.sidebar.radio(label = 'Select the Document',
218
  # help = 'You can upload the document \
219
  # or else you can try a example document',
 
245
  #----Model Select logic-------
246
  if model_sel == "chatGPT":
247
  model_name = "gpt-3.5-turbo"
248
+ # # Initialize the PromptNode
249
+ # pn = PromptNode(model_name_or_path=model_name, default_prompt_template=template, api_key=openai_key, max_length=2000, model_kwargs={"generation_kwargs": {"do_sample": False, "temperature": 0}})
250
 
251
+ # # Initialize the pipeline
252
+ # pipe = Pipeline()
253
+ # pipe.add_node(component=pn, name="prompt_node", inputs=["Query"])
254
  else:
255
  # Currently disabled
256
  model = "meta-llama/Llama-2-70b-chat-hf"