notSoNLPnerd committed on
Commit
bd2e0e7
1 Parent(s): e09fe1d

Additional minimal UI changes, heavy refactoring

Browse files
Files changed (5) hide show
  1. app.py +12 -52
  2. backend_utils.py +0 -107
  3. utils/__init__.py +0 -0
  4. utils/constants.py +10 -0
  5. utils/ui.py +114 -0
app.py CHANGED
@@ -1,66 +1,26 @@
1
  import streamlit as st
2
- from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
3
- get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES,
4
- PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS)
 
5
 
6
  st.set_page_config(
7
  page_title="Retrieval Augmentation with Haystack",
 
8
  )
 
9
 
10
  st.markdown("<center> <h2> Reduce Hallucinations with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)
11
 
12
  st.markdown("Ask a question about the collapse of the Silicon Valley Bank (SVB).", unsafe_allow_html=True)
13
 
14
- # if not st.session_state.get('pipelines_loaded', False):
15
- # with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
16
- # p1, p2, p3 = app_init()
17
- # st.success('Pipelines are loaded', icon="✅")
18
- # st.session_state['pipelines_loaded'] = True
19
 
20
- placeholder = st.empty()
21
- with placeholder:
22
- search_bar, button = st.columns([3, 1])
23
- with search_bar:
24
- username = st.text_area(f" ", max_chars=200, key='query')
25
-
26
- with button:
27
- st.write(" ")
28
- st.write(" ")
29
- run_pressed = st.button("Run")
30
-
31
- st.markdown("<center> <h5> Example questions </h5> </center>", unsafe_allow_html=True)
32
-
33
- st.write(" ")
34
- st.write(" ")
35
- c1, c2, c3, c4, c5 = st.columns(5)
36
- with c1:
37
- st.button(QUERIES[0], on_click=set_q1)
38
- with c2:
39
- st.button(QUERIES[1], on_click=set_q2)
40
- with c3:
41
- st.button(QUERIES[2], on_click=set_q3)
42
- with c4:
43
- st.button(QUERIES[3], on_click=set_q4)
44
- with c5:
45
- st.button(QUERIES[4], on_click=set_q5)
46
-
47
- st.write(" ")
48
- st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieval Augmented with Web Search"), key="query_type")
49
-
50
- # st.sidebar.selectbox(
51
- # "Example Questions:",
52
- # QUERIES,
53
- # key='q_drop_down', on_change=set_question)
54
-
55
- st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
56
- placeholder_plain_gpt = st.empty()
57
- st.text(" ")
58
- st.text(" ")
59
- if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
60
- st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
61
- else:
62
- st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
63
- placeholder_retrieval_augmented = st.empty()
64
 
65
  if st.session_state.get('query') and run_pressed:
66
  ip = st.session_state['query']
 
1
  import streamlit as st
2
+ from utils.backend import (get_plain_pipeline, get_retrieval_augmented_pipeline,
3
+ get_web_retrieval_augmented_pipeline)
4
+ from utils.ui import set_q1, set_q2, set_q3, set_q4, set_q5, left_sidebar, right_sidebar, main_column
5
+ from utils.constants import QUERIES, PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS
6
 
7
  st.set_page_config(
8
  page_title="Retrieval Augmentation with Haystack",
9
+ layout="wide"
10
  )
11
+ left_sidebar()
12
 
13
  st.markdown("<center> <h2> Reduce Hallucinations with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)
14
 
15
  st.markdown("Ask a question about the collapse of the Silicon Valley Bank (SVB).", unsafe_allow_html=True)
16
 
17
+ col_1, col_2 = st.columns([4, 2], gap="small")
18
+ with col_1:
19
+ run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented = main_column()
20
+ print(f"Run value: {st.session_state.get('run', 'not found')}")
 
21
 
22
+ with col_2:
23
+ right_sidebar()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  if st.session_state.get('query') and run_pressed:
26
  ip = st.session_state['query']
backend_utils.py DELETED
@@ -1,107 +0,0 @@
1
- import streamlit as st
2
- from haystack import Pipeline
3
- from haystack.document_stores import FAISSDocumentStore
4
- from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
5
- from haystack.nodes.retriever.web import WebRetriever
6
-
7
-
8
- QUERIES = [
9
- "Did SVB collapse?",
10
- "Why did SVB collapse?",
11
- "What does SVB failure mean for our economy?",
12
- "Who is responsible for SVC collapse?",
13
- "When did SVB collapse?"
14
- ]
15
- PLAIN_GPT_ANS = "Answer with plain GPT"
16
- GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Static news dataset)"
17
- GPT_WEB_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Web Search)"
18
-
19
-
20
- @st.cache_resource(show_spinner=False)
21
- def get_plain_pipeline():
22
- prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
23
- # Now let make one PromptNode use the default model and the other one the OpenAI model:
24
- plain_llm_template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: $query")
25
- node_openai = PromptNode(prompt_open_ai, default_prompt_template=plain_llm_template, max_length=300)
26
- pipeline = Pipeline()
27
- pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
28
- return pipeline
29
-
30
-
31
- @st.cache_resource(show_spinner=False)
32
- def get_retrieval_augmented_pipeline():
33
- ds = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
34
- faiss_config_path="data/my_faiss_index.json")
35
-
36
- retriever = EmbeddingRetriever(
37
- document_store=ds,
38
- embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
39
- model_format="sentence_transformers",
40
- top_k=2
41
- )
42
- shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
43
-
44
- default_template = PromptTemplate(
45
- name="question-answering",
46
- prompt_text="Given the context please answer the question. Context: $documents; Question: "
47
- "$query; Answer:",
48
- )
49
- # Let's initiate the PromptNode
50
- node = PromptNode("text-davinci-003", default_prompt_template=default_template,
51
- api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
52
-
53
- # Let's create a pipeline with Shaper and PromptNode
54
- pipeline = Pipeline()
55
- pipeline.add_node(component=retriever, name='retriever', inputs=['Query'])
56
- pipeline.add_node(component=shaper, name="shaper", inputs=["retriever"])
57
- pipeline.add_node(component=node, name="prompt_node", inputs=["shaper"])
58
- return pipeline
59
-
60
-
61
- @st.cache_resource(show_spinner=False)
62
- def get_web_retrieval_augmented_pipeline():
63
- search_key = st.secrets["WEBRET_API_KEY"]
64
- web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
65
- shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
66
- default_template = PromptTemplate(
67
- name="question-answering",
68
- prompt_text="Given the context please answer the question. Context: $documents; Question: "
69
- "$query; Answer:",
70
- )
71
- # Let's initiate the PromptNode
72
- node = PromptNode("text-davinci-003", default_prompt_template=default_template,
73
- api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
74
- # Let's create a pipeline with Shaper and PromptNode
75
- pipeline = Pipeline()
76
- pipeline.add_node(component=web_retriever, name='retriever', inputs=['Query'])
77
- pipeline.add_node(component=shaper, name="shaper", inputs=["retriever"])
78
- pipeline.add_node(component=node, name="prompt_node", inputs=["shaper"])
79
- return pipeline
80
-
81
-
82
- if 'query' not in st.session_state:
83
- st.session_state['query'] = ""
84
-
85
-
86
- def set_question():
87
- st.session_state['query'] = st.session_state['q_drop_down']
88
-
89
-
90
- def set_q1():
91
- st.session_state['query'] = QUERIES[0]
92
-
93
-
94
- def set_q2():
95
- st.session_state['query'] = QUERIES[1]
96
-
97
-
98
- def set_q3():
99
- st.session_state['query'] = QUERIES[2]
100
-
101
-
102
- def set_q4():
103
- st.session_state['query'] = QUERIES[3]
104
-
105
-
106
- def set_q5():
107
- st.session_state['query'] = QUERIES[4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__init__.py ADDED
File without changes
utils/constants.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ QUERIES = [
2
+ "Did SVB collapse?",
3
+ "Why did SVB collapse?",
4
+ "What does SVB failure mean for our economy?",
5
+ "Who is responsible for SVC collapse?",
6
+ "When did SVB collapse?"
7
+ ]
8
+ PLAIN_GPT_ANS = "Answer with plain GPT"
9
+ GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Static news dataset)"
10
+ GPT_WEB_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Web Search)"
utils/ui.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+
4
+ from .constants import QUERIES, PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS
5
+
6
+
7
+ def set_question():
8
+ st.session_state['query'] = st.session_state['q_drop_down']
9
+
10
+
11
+ def set_q1():
12
+ st.session_state['query'] = QUERIES[0]
13
+
14
+
15
+ def set_q2():
16
+ st.session_state['query'] = QUERIES[1]
17
+
18
+
19
+ def set_q3():
20
+ st.session_state['query'] = QUERIES[2]
21
+
22
+
23
+ def set_q4():
24
+ st.session_state['query'] = QUERIES[3]
25
+
26
+
27
+ def set_q5():
28
+ st.session_state['query'] = QUERIES[4]
29
+
30
+ def main_column():
31
+ placeholder = st.empty()
32
+ with placeholder:
33
+ search_bar, button = st.columns([3, 1])
34
+ with search_bar:
35
+ username = st.text_area(f" ", max_chars=200, key='query')
36
+
37
+ with button:
38
+ st.write(" ")
39
+ st.write(" ")
40
+ run_pressed = st.button("Run", key="run")
41
+
42
+ st.write(" ")
43
+ st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieval Augmented with Web Search"), key="query_type")
44
+
45
+ # st.sidebar.selectbox(
46
+ # "Example Questions:",
47
+ # QUERIES,
48
+ # key='q_drop_down', on_change=set_question)
49
+
50
+ st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
51
+ placeholder_plain_gpt = st.empty()
52
+ st.text(" ")
53
+ st.text(" ")
54
+ if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
55
+ st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
56
+ else:
57
+ st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
58
+ placeholder_retrieval_augmented = st.empty()
59
+
60
+ return run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented
61
+
62
+
63
+ def right_sidebar():
64
+ st.markdown("<h5> Example questions </h5>", unsafe_allow_html=True)
65
+ # c1, c2, c3, c4, c5 = st.columns(5)
66
+ # with c1:
67
+ st.button(QUERIES[0], on_click=set_q1)
68
+ # with c2:
69
+ st.button(QUERIES[1], on_click=set_q2)
70
+ # with c3:
71
+ st.button(QUERIES[2], on_click=set_q3)
72
+ # with c4:
73
+ st.button(QUERIES[3], on_click=set_q4)
74
+ # with c5:
75
+ st.button(QUERIES[4], on_click=set_q5)
76
+
77
+
78
+ def left_sidebar():
79
+ with st.sidebar:
80
+ image = Image.open('logo/haystack-logo-colored.png')
81
+ st.markdown("Thanks for coming to this 🤗 Space.\n\n"
82
+ "This is an effort towards showcasing how can you use Haystack for Retrieval Augmented QA, "
83
+ "with local document store as well as WebRetriever (coming soon!) \n\n"
84
+ "For more on how this was built, instructions along with a Repository "
85
+ "will be published soon and updated here.")
86
+
87
+ # st.markdown(
88
+ # "## How to use\n"
89
+ # "1. Enter your [OpenAI API key](https://platform.openai.com/account/api-keys) below\n"
90
+ # "2. Enter a Serper Dev API key\n"
91
+ # "3. Enjoy 🤗\n"
92
+ # )
93
+
94
+ # api_key_input = st.text_input(
95
+ # "OpenAI API Key",
96
+ # type="password",
97
+ # placeholder="Paste your OpenAI API key here (sk-...)",
98
+ # help="You can get your API key from https://platform.openai.com/account/api-keys.",
99
+ # value=st.session_state.get("OPENAI_API_KEY", ""),
100
+ # )
101
+
102
+ # if api_key_input:
103
+ # set_openai_api_key(api_key_input)
104
+
105
+ st.markdown("---")
106
+ st.markdown(
107
+ "## How this works\n"
108
+ "This app was built with [Haystack](https://haystack.deepset.ai) using the"
109
+ " [`PromptNode`](https://docs.haystack.deepset.ai/docs/prompt_node) and [`Retriever`](https://docs.haystack.deepset.ai/docs/retriever#embedding-retrieval-recommended).\n\n"
110
+ " You can find the source code in **Files and versions** tab."
111
+ )
112
+
113
+ st.markdown("---")
114
+ st.image(image, width=250)