Update app.py
app.py
CHANGED
@@ -109,7 +109,7 @@ def get_docs(input_query, country = [], vulnerability_cat = []):
         filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
     else:
         filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
-    docs = retriever.retrieve(query=
+    docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
     # Break out the key fields and convert to pandas for filtering
     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
     df_docs = pd.DataFrame(docs)
@@ -154,11 +154,11 @@ def get_refs(docs, res):
     return result_str
 
 # define a special function for putting the prompt together (as we can't use haystack)
-def get_prompt(docs,
+def get_prompt(docs, input_query):
     base_prompt=prompt_template
     # Add the meta data for references
     context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
-    prompt = base_prompt+"; Context: "+context+"; Question: "+
+    prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
     return(prompt)
 
 def run_query(input_query, country, model_sel):
@@ -167,13 +167,13 @@ def run_query(input_query, country, model_sel):
     # st.write('Selected country: ', country) # Debugging country
     if model_sel == "chatGPT":
         # res = pipe.run(query=input_text, documents=docs)
-        res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs,
+        res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_query)}])
         output = res["results"][0]
         references = get_refs(docs, output)
-    else:
-
-
-
+    # else:
+    #     res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
+    #     output = res
+    #     references = get_refs(docs, res)
     st.write('Response')
     st.success(output)
     st.write('References')
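Taken together, the change threads the user's input_query through all three steps: retrieval, prompt assembly, and the chat completion call. Below is a minimal sketch of that flow, assuming retriever, prompt_template, and filters are set up elsewhere in app.py as in the surrounding code; it illustrates the intended wiring rather than reproducing the file verbatim.

# Sketch of the post-change flow: input_query -> retrieval -> prompt -> completion.
# `retriever`, `prompt_template`, and `filters` are assumed to exist elsewhere in app.py.
import openai

def get_prompt(docs, input_query):
    # Prefix each retrieved passage with its ref_id and source document so the
    # answer's citations can later be matched back by get_refs()
    context = ' - '.join(
        '&&& [ref. ' + str(d.meta['ref_id']) + '] ' + d.meta['document'] + ' &&&: ' + d.content
        for d in docs
    )
    return prompt_template + "; Context: " + context + "; Question: " + input_query + "; Answer:"

# Retrieval restricted by the country / vulnerability_cat filters built earlier
docs = retriever.retrieve(query=input_query, filters=filters, top_k=10)

# The assembled prompt goes to gpt-3.5-turbo as a single user message; with the
# pre-1.0 openai client the reply text lives at res["choices"][0]["message"]["content"]
res = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": get_prompt(docs, input_query)}],
)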