mtyrrell committed on
Commit
8e99f61
1 Parent(s): f5548b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -79,21 +79,19 @@ def get_docs(input_query, country = None):
79
  ls_dict.append(doc)
80
  return(ls_dict)
81
 
82
- def get_refs(res):
83
  '''
84
  Parse response for engineered reference ids (refer to prompt template)
85
  Extract documents using reference ids
86
  '''
87
- import re
88
- text = res["results"][0]
89
- # This pattern should be returned by gpt3.5
90
  # pattern = r'ref\. (\d+)\]\.'
91
  pattern = r'ref\. (\d+)'
92
- ref_ids = [int(match) for match in re.findall(pattern, text)]
93
  # extract
94
  result_str = "" # Initialize an empty string to store the result
95
  for i in range(len(res['documents'])):
96
- doc = res['documents'][i].to_dict()
97
  ref_id = doc['meta']['ref_id']
98
  if ref_id in ref_ids:
99
  result_str += "**Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document_name'] + "]:** " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
@@ -106,7 +104,7 @@ def run_query(input_text, country):
106
  output = res["results"][0]
107
  st.write('Response')
108
  st.success(output)
109
- references = get_refs(res)
110
  st.write('References')
111
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
112
  st.markdown(references, unsafe_allow_html=True)
 
79
  ls_dict.append(doc)
80
  return(ls_dict)
81
 
82
+ def get_refs(docs, res):
83
  '''
84
  Parse response for engineered reference ids (refer to prompt template)
85
  Extract documents using reference ids
86
  '''
87
+ # This pattern should be returned by gpt3.5 & llama2
 
 
88
  # pattern = r'ref\. (\d+)\]\.'
89
  pattern = r'ref\. (\d+)'
90
+ ref_ids = [int(match) for match in re.findall(pattern, res)]
91
  # extract
92
  result_str = "" # Initialize an empty string to store the result
93
  for i in range(len(res['documents'])):
94
+ doc = docs[i].to_dict()
95
  ref_id = doc['meta']['ref_id']
96
  if ref_id in ref_ids:
97
  result_str += "**Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document_name'] + "]:** " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
 
104
  output = res["results"][0]
105
  st.write('Response')
106
  st.success(output)
107
+ references = get_refs(docs, res["results"][0])
108
  st.write('References')
109
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
110
  st.markdown(references, unsafe_allow_html=True)