notSoNLPnerd committed on
Commit
4a448eb
·
1 Parent(s): 65935d6

Working all

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [theme]
2
+ base = "light"
3
+ font="monospace"
app.py CHANGED
@@ -1,116 +1,59 @@
1
- import glob
2
- import os
3
- import logging
4
- import sys
5
-
6
  import streamlit as st
7
- from haystack import Pipeline
8
- from haystack.document_stores import FAISSDocumentStore
9
- from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
10
- from haystack.nodes.retriever.web import WebRetriever
11
- from haystack.schema import Document
12
-
13
- logging.basicConfig(
14
- level=logging.DEBUG,
15
- format="%(levelname)s %(asctime)s %(name)s:%(message)s",
16
- handlers=[logging.StreamHandler(sys.stdout)],
17
- force=True,
18
- )
19
-
20
- def get_plain_pipeline():
21
- prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
22
-
23
- # Now let make one PromptNode use the default model and the other one the OpenAI model:
24
- plain_llm_template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: $query")
25
- node_openai = PromptNode(prompt_open_ai, default_prompt_template=plain_llm_template, max_length=300)
26
-
27
- pipeline = Pipeline()
28
- pipeline.add_node(component=node_openai, name="prompt_node", inputs=["Query"])
29
- return pipeline
30
-
31
-
32
- def get_ret_aug_pipeline():
33
- ds = FAISSDocumentStore(faiss_index_path="my_faiss_index.faiss",
34
- faiss_config_path="my_faiss_index.json")
35
-
36
- retriever = EmbeddingRetriever(
37
- document_store=ds,
38
- embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
39
- model_format="sentence_transformers",
40
- top_k=2
41
- )
42
- shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
43
-
44
- default_template= PromptTemplate(
45
- name="question-answering",
46
- prompt_text="Given the context please answer the question. Context: $documents; Question: "
47
- "$query; Answer:",
48
- )
49
- # Let's initiate the PromptNode
50
- node = PromptNode("text-davinci-003", default_prompt_template=default_template,
51
- api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
52
-
53
- # Let's create a pipeline with Shaper and PromptNode
54
- pipe = Pipeline()
55
- pipe.add_node(component=retriever, name='retriever', inputs=['Query'])
56
- pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
57
- pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
58
- return pipe
59
-
60
-
61
- def get_web_ret_pipeline():
62
- search_key = st.secrets["WEBRET_API_KEY"]
63
- web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
64
- shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
65
- default_template = PromptTemplate(
66
- name="question-answering",
67
- prompt_text="Given the context please answer the question. Context: $documents; Question: "
68
- "$query; Answer:",
69
- )
70
- # Let's initiate the PromptNode
71
- node = PromptNode("text-davinci-003", default_prompt_template=default_template,
72
- api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
73
- # Let's create a pipeline with Shaper and PromptNode
74
- pipe = Pipeline()
75
- pipe.add_node(component=web_retriever, name='retriever', inputs=['Query'])
76
- pipe.add_node(component=shaper, name="shaper", inputs=["retriever"])
77
- pipe.add_node(component=node, name="prompt_node", inputs=["shaper"])
78
- return pipe
79
-
80
- def app_init():
81
- os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
82
- p1 = get_plain_pipeline()
83
- p2 = get_ret_aug_pipeline()
84
- p3 = get_web_ret_pipeline()
85
- return p1, p2, p3
86
-
87
-
88
- def main():
89
  p1, p2, p3 = app_init()
90
- st.title("Haystack Demo")
91
- input = st.text_input("Query ...", "Did SVB collapse?")
92
-
93
- query_type = st.radio("Type",
94
- ("Retrieval Augmented", "Retrieval Augmented with Web Search"))
95
- # col_1, col_2 = st.columns(2)
96
-
97
- if st.button("Random Question"):
98
- new_text = "Streamlit is great!"
99
- input.value = new_text
100
-
101
- # with col_1:
102
- # st.text("PLAIN")
103
  answers = p1.run(input)
104
- st.text_area("PLAIN GPT", answers['results'][0])
105
 
106
- # with col_2:
107
- # st.write(query_type.upper())
108
- if query_type == "Retrieval Augmented":
109
  answers_2 = p2.run(input)
110
  else:
111
  answers_2 = p3.run(input)
112
- st.text_area(query_type.upper(),answers_2['results'][0])
113
-
114
-
115
- if __name__ == "__main__":
116
- main()
 
 
 
 
 
 
1
  import streamlit as st
2
+ from backend_utils import app_init, set_q1, set_q2, set_q3, set_q4, set_q5
3
+
4
st.markdown("<center> <h1> Haystack Demo </h1> </center>", unsafe_allow_html=True)

# Build the pipelines once per browser session, showing a spinner meanwhile.
# The guard has to be negated: 'pipelines_loaded' starts out unset and is only
# set inside this block, so testing it for truthiness meant the block could
# never run and the spinner/success message were never shown.
if not st.session_state.get('pipelines_loaded', False):
    with st.spinner('Loading pipelines...'):
        # The return value is not needed here; app_init() is called again
        # when a query is actually run.
        app_init()
    st.success('Pipelines are loaded', icon="✅")
    st.session_state['pipelines_loaded'] = True

# Query box and "Run" button, laid out side by side (3:1 width ratio).
placeholder = st.empty()
with placeholder:
    search_bar, button = st.columns([3, 1])
    with search_bar:
        # The widget is keyed 'query', so its text is read back from
        # st.session_state['query'] below.
        st.text_area("", max_chars=200, key='query')

    with button:
        # Presumably spacers that line the button up with the text area.
        st.write("")
        st.write("")
        run_pressed = st.button("Run")

# The selected mode is stored under st.session_state['query_type'].
st.radio("Type", ("Retrieval Augmented", "Retrieval Augmented with Web Search"), key="query_type")

# st.sidebar.selectbox(
#     "Example Questions:",
#     QUERIES,
#     key='q_drop_down', on_change=set_question)

# One button per example question; each callback writes its question into
# st.session_state['query'].
example_callbacks = (set_q1, set_q2, set_q3, set_q4, set_q5)
for number, (column, callback) in enumerate(zip(st.columns(5), example_callbacks), start=1):
    with column:
        st.button(f'Example Q{number}', on_click=callback)

# Headings plus empty slots that are filled in once answers are available.
st.markdown("<h4> Answer with PLAIN GPT </h4>", unsafe_allow_html=True)
placeholder_plain_gpt = st.empty()
st.text("")
st.text("")
st.markdown(f"<h4> Answer with {st.session_state['query_type'].upper()} </h4>", unsafe_allow_html=True)
placeholder_retrieval_augmented = st.empty()

# Only query the models when "Run" was pressed and the query is non-empty.
if st.session_state.get('query') and run_pressed:
    # Named 'query' rather than 'input' so the builtin is not shadowed.
    query = st.session_state['query']
    p1, p2, p3 = app_init()
    answers = p1.run(query)
    placeholder_plain_gpt.markdown(answers['results'][0])

    if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
        answers_2 = p2.run(query)
    else:
        answers_2 = p3.run(query)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])
 
 
 
 
backend_utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from haystack import Pipeline
5
+ from haystack.document_stores import FAISSDocumentStore
6
+ from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
7
+ from haystack.nodes.retriever.web import WebRetriever
8
+
9
+
10
# Example questions offered in the demo UI; the set_q* callbacks index into
# this list by position, so the order matters.
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",  # was misspelled "SVC"
    "When did SVB collapse?",
]
17
+
18
def ChangeWidgetFontSize(wgt_txt, wch_font_size='12px'):
    """Build a ``<script>`` snippet that changes the font size of elements by text.

    The generated JavaScript walks every element of ``window.parent.document``
    and sets ``style.fontSize`` on each one whose ``innerText`` equals
    *wgt_txt*.

    Previously the snippet was built and then discarded, so calling this
    function had no effect; it is now returned.

    Args:
        wgt_txt: Exact text of the element(s) to restyle.
        wch_font_size: CSS font-size value to apply (default ``'12px'``).

    Returns:
        The HTML/JavaScript snippet as a string. Injecting it into the page
        is left to the caller.
    """
    import json  # local import keeps this block self-contained

    # json.dumps produces quoted, escaped JavaScript string literals, so text
    # containing quotes or backslashes cannot break the script. "</" is also
    # escaped so the text cannot close the surrounding <script> tag early.
    js_text = json.dumps(wgt_txt).replace("</", "<\\/")
    js_size = json.dumps(wch_font_size).replace("</", "<\\/")
    return (
        "<script>var elements = window.parent.document.querySelectorAll('*'), i;\n"
        "for (i = 0; i < elements.length; ++i) { "
        f"if (elements[i].innerText == {js_text}) "
        f"{{ elements[i].style.fontSize = {js_size}; }} }} </script>"
    )
24
+
25
+
26
def get_plain_pipeline():
    """Build the single-node pipeline that sends the query straight to the LLM.

    The only node, "prompt_node", is fed directly from "Query" and uses the
    "plain_llm" template with the "text-davinci-003" model. The OpenAI key is
    read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        The assembled ``Pipeline``.
    """
    openai_model = PromptModel(
        model_name_or_path="text-davinci-003",
        api_key=st.secrets["OPENAI_API_KEY"],
    )
    question_template = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )
    prompt_node = PromptNode(
        openai_model,
        default_prompt_template=question_template,
        max_length=300,
    )

    plain_pipeline = Pipeline()
    plain_pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["Query"])
    return plain_pipeline
34
+
35
+
36
def get_retrieval_augmented_pipeline():
    """Build the retrieval-augmented pipeline backed by the local FAISS index.

    Nodes are chained Query -> "retriever" -> "shaper" -> "prompt_node". The
    document store is loaded from ``data/my_faiss_index.faiss`` and its JSON
    config; the OpenAI key is read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        The assembled ``Pipeline``.
    """
    document_store = FAISSDocumentStore(
        faiss_index_path="data/my_faiss_index.faiss",
        faiss_config_path="data/my_faiss_index.json",
    )
    embedding_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    join_shaper = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text=(
            "Given the context please answer the question. Context: $documents; Question: "
            "$query; Answer:"
        ),
    )
    prompt_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    # Wire the nodes in order; each stage lists the node it is fed from.
    rag_pipeline = Pipeline()
    stages = (
        (embedding_retriever, "retriever", ["Query"]),
        (join_shaper, "shaper", ["retriever"]),
        (prompt_node, "prompt_node", ["shaper"]),
    )
    for component, node_name, upstream in stages:
        rag_pipeline.add_node(component=component, name=node_name, inputs=upstream)
    return rag_pipeline
63
+
64
+
65
def get_web_retrieval_augmented_pipeline():
    """Build the retrieval-augmented pipeline that retrieves from a web search.

    Nodes are chained Query -> "retriever" -> "shaper" -> "prompt_node". The
    retriever is a ``WebRetriever`` configured for the "SerperDev" provider
    with the key from ``st.secrets["WEBRET_API_KEY"]``; the OpenAI key is read
    from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        The assembled ``Pipeline``.
    """
    web_retriever = WebRetriever(
        api_key=st.secrets["WEBRET_API_KEY"],
        search_engine_provider="SerperDev",
    )
    join_shaper = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text=(
            "Given the context please answer the question. Context: $documents; Question: "
            "$query; Answer:"
        ),
    )
    prompt_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    # Wire the nodes in order; each stage lists the node it is fed from.
    web_pipeline = Pipeline()
    stages = (
        (web_retriever, "retriever", ["Query"]),
        (join_shaper, "shaper", ["retriever"]),
        (prompt_node, "prompt_node", ["shaper"]),
    )
    for component, node_name, upstream in stages:
        web_pipeline.add_node(component=component, name=node_name, inputs=upstream)
    return web_pipeline
83
+
84
+
85
@st.cache_resource(show_spinner=False)
def app_init():
    """Export the OpenAI key to the environment and build all three pipelines.

    Decorated with ``st.cache_resource`` (its own spinner disabled).

    Returns:
        A 3-tuple ``(plain, retrieval_augmented, web_retrieval_augmented)``
        of pipelines, built in that order.
    """
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    builders = (
        get_plain_pipeline,
        get_retrieval_augmented_pipeline,
        get_web_retrieval_augmented_pipeline,
    )
    return tuple(build() for build in builders)
92
+
93
+
94
# Make sure the 'query' entry exists (as an empty string) without
# overwriting a value that is already present.
st.session_state.setdefault('query', "")
96
+
97
+
98
def set_question():
    """Copy the 'q_drop_down' session-state value into the 'query' entry."""
    selected_question = st.session_state['q_drop_down']
    st.session_state['query'] = selected_question
100
+
101
+
102
def _use_example_query(position):
    """Store ``QUERIES[position]`` in the 'query' session-state entry."""
    st.session_state['query'] = QUERIES[position]


def set_q1():
    """Set the query to the first example question."""
    _use_example_query(0)


def set_q2():
    """Set the query to the second example question."""
    _use_example_query(1)


def set_q3():
    """Set the query to the third example question."""
    _use_example_query(2)


def set_q4():
    """Set the query to the fourth example question."""
    _use_example_query(3)


def set_q5():
    """Set the query to the fifth example question."""
    _use_example_query(4)
120
+
my_faiss_index.faiss → data/my_faiss_index.faiss RENAMED
File without changes
my_faiss_index.json → data/my_faiss_index.json RENAMED
File without changes
data/sample_1.txt DELETED
@@ -1 +0,0 @@
1
- Hello World 1!
 
 
data/sample_2.txt DELETED
@@ -1 +0,0 @@
1
- Hello World 2!
 
 
my_faiss_config.json DELETED
@@ -1 +0,0 @@
1
- {"faiss_config_path": "my_faiss_config.json", "embedding_dim": 768}
 
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ git+https://github.com/deepset-ai/haystack.git@ffd02c29f7cc83a119b6440bfbabaacda
2
  faiss-cpu==1.7.2
3
  sqlalchemy>=1.4.2,<2
4
  sqlalchemy_utils
5
- psycopg2-binary
 
 
2
  faiss-cpu==1.7.2
3
  sqlalchemy>=1.4.2,<2
4
  sqlalchemy_utils
5
+ psycopg2-binary
6
+ streamlit==1.19.0