Spaces:
Runtime error
Runtime error
notSoNLPnerd
commited on
Commit
•
e09fe1d
1
Parent(s):
3842297
final tiny changes
Browse files- app.py +11 -10
- backend_utils.py +3 -12
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
|
3 |
-
get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES
|
|
|
4 |
|
5 |
st.set_page_config(
|
6 |
page_title="Retrieval Augmentation with Haystack",
|
@@ -51,39 +52,39 @@ st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieva
|
|
51 |
# QUERIES,
|
52 |
# key='q_drop_down', on_change=set_question)
|
53 |
|
54 |
-
st.markdown("<h5>
|
55 |
placeholder_plain_gpt = st.empty()
|
56 |
st.text(" ")
|
57 |
st.text(" ")
|
58 |
if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
|
59 |
-
st.markdown("<h5>
|
60 |
else:
|
61 |
-
st.markdown("<h5>
|
62 |
placeholder_retrieval_augmented = st.empty()
|
63 |
|
64 |
if st.session_state.get('query') and run_pressed:
|
65 |
ip = st.session_state['query']
|
66 |
with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
|
67 |
p1 = get_plain_pipeline()
|
68 |
-
with st.spinner('Fetching answers from GPT
|
69 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
70 |
answers = p1.run(ip)
|
71 |
placeholder_plain_gpt.markdown(answers['results'][0])
|
72 |
|
73 |
if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
|
74 |
with st.spinner(
|
75 |
-
'Loading Retrieval Augmented pipeline...
|
76 |
-
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
77 |
p2 = get_retrieval_augmented_pipeline()
|
78 |
-
with st.spinner('
|
79 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
80 |
answers_2 = p2.run(ip)
|
81 |
else:
|
82 |
with st.spinner(
|
83 |
-
'Loading Retrieval Augmented pipeline... \
|
84 |
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
85 |
p3 = get_web_retrieval_augmented_pipeline()
|
86 |
-
with st.spinner('
|
87 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
88 |
answers_2 = p3.run(ip)
|
89 |
placeholder_retrieval_augmented.markdown(answers_2['results'][0])
|
|
|
1 |
import streamlit as st
|
2 |
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
|
3 |
+
get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES,
|
4 |
+
PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS)
|
5 |
|
6 |
st.set_page_config(
|
7 |
page_title="Retrieval Augmentation with Haystack",
|
|
|
52 |
# QUERIES,
|
53 |
# key='q_drop_down', on_change=set_question)
|
54 |
|
55 |
+
st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
|
56 |
placeholder_plain_gpt = st.empty()
|
57 |
st.text(" ")
|
58 |
st.text(" ")
|
59 |
if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
|
60 |
+
st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
|
61 |
else:
|
62 |
+
st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
|
63 |
placeholder_retrieval_augmented = st.empty()
|
64 |
|
65 |
if st.session_state.get('query') and run_pressed:
|
66 |
ip = st.session_state['query']
|
67 |
with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
|
68 |
p1 = get_plain_pipeline()
|
69 |
+
with st.spinner('Fetching answers from plain GPT... '
|
70 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
71 |
answers = p1.run(ip)
|
72 |
placeholder_plain_gpt.markdown(answers['results'][0])
|
73 |
|
74 |
if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
|
75 |
with st.spinner(
|
76 |
+
'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
|
77 |
+
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
78 |
p2 = get_retrieval_augmented_pipeline()
|
79 |
+
with st.spinner('Getting relevant documents from documented stores and calculating answers... '
|
80 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
81 |
answers_2 = p2.run(ip)
|
82 |
else:
|
83 |
with st.spinner(
|
84 |
+
'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... \
|
85 |
n This may take a few mins and might also fail if OpenAI API server is down.'):
|
86 |
p3 = get_web_retrieval_augmented_pipeline()
|
87 |
+
with st.spinner('Getting relevant documents from the Web and calculating answers... '
|
88 |
'\n This may take a few mins and might also fail if OpenAI API server is down.'):
|
89 |
answers_2 = p3.run(ip)
|
90 |
placeholder_retrieval_augmented.markdown(answers_2['results'][0])
|
backend_utils.py
CHANGED
@@ -12,6 +12,9 @@ QUERIES = [
|
|
12 |
"Who is responsible for SVC collapse?",
|
13 |
"When did SVB collapse?"
|
14 |
]
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
@st.cache_resource(show_spinner=False)
|
@@ -76,18 +79,6 @@ def get_web_retrieval_augmented_pipeline():
|
|
76 |
return pipeline
|
77 |
|
78 |
|
79 |
-
# @st.cache_resource(show_spinner=False)
|
80 |
-
# def app_init():
|
81 |
-
# print("Loading Pipelines...")
|
82 |
-
# p1 = get_plain_pipeline()
|
83 |
-
# print("Loaded Plain Pipeline")
|
84 |
-
# p2 = get_retrieval_augmented_pipeline()
|
85 |
-
# print("Loaded Retrieval Augmented Pipeline")
|
86 |
-
# p3 = get_web_retrieval_augmented_pipeline()
|
87 |
-
# print("Loaded Web Retrieval Augmented Pipeline")
|
88 |
-
# return p1, p2, p3
|
89 |
-
|
90 |
-
|
91 |
if 'query' not in st.session_state:
|
92 |
st.session_state['query'] = ""
|
93 |
|
|
|
12 |
"Who is responsible for SVC collapse?",
|
13 |
"When did SVB collapse?"
|
14 |
]
|
15 |
+
PLAIN_GPT_ANS = "Answer with plain GPT"
|
16 |
+
GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Static news dataset)"
|
17 |
+
GPT_WEB_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Web Search)"
|
18 |
|
19 |
|
20 |
@st.cache_resource(show_spinner=False)
|
|
|
79 |
return pipeline
|
80 |
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if 'query' not in st.session_state:
|
83 |
st.session_state['query'] = ""
|
84 |
|