Update app.py
Browse files
app.py
CHANGED
@@ -46,11 +46,10 @@ examples = [
|
|
46 |
|
47 |
def get_docs(input_query, country = None):
|
48 |
# Construct a hacky query to focus the retriever on the target country (see notes below)
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
# query = input_query
|
54 |
|
55 |
# Get top 150 because we want to make sure we have 10 pertaining to the selected country
|
56 |
# TEMP SOLUTION: not ideal, but FAISS document store doesnt allow metadata filtering. Needs to be tested with the full dataset
|
@@ -98,7 +97,7 @@ def get_refs(res):
|
|
98 |
result_str += "Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document_name'] + "]: " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
|
99 |
return result_str
|
100 |
|
101 |
-
def run_query(input_text):
|
102 |
docs = get_docs(input_text)
|
103 |
st.write('Selected country: ', country) # Debugging country
|
104 |
res = pipe.run(query=input_text, documents=docs)
|
@@ -170,4 +169,4 @@ else:
|
|
170 |
|
171 |
|
172 |
if st.button('Submit'):
|
173 |
-
run_query(text)
|
|
|
46 |
|
47 |
def get_docs(input_query, country = None):
|
48 |
# Construct a hacky query to focus the retriever on the target country (see notes below)
|
49 |
+
if country:
|
50 |
+
query = "For the country of "+country+", "+input_query
|
51 |
+
else:
|
52 |
+
query = input_query
|
|
|
53 |
|
54 |
# Get top 150 because we want to make sure we have 10 pertaining to the selected country
|
55 |
# TEMP SOLUTION: not ideal, but FAISS document store doesnt allow metadata filtering. Needs to be tested with the full dataset
|
|
|
97 |
result_str += "Ref. " + str(ref_id) + " [" + doc['meta']['country'] + " " + doc['meta']['document_name'] + "]: " + "*'" + doc['content'] + "'*<br> <br>" # Add <br> for a line break
|
98 |
return result_str
|
99 |
|
100 |
+
def run_query(input_text, country):
|
101 |
docs = get_docs(input_text)
|
102 |
st.write('Selected country: ', country) # Debugging country
|
103 |
res = pipe.run(query=input_text, documents=docs)
|
|
|
169 |
|
170 |
|
171 |
if st.button('Submit'):
|
172 |
+
run_query(text, country)
|