notSoNLPnerd commited on
Commit
e09fe1d
·
1 Parent(s): 3842297

final tiny changes

Browse files
Files changed (2) hide show
  1. app.py +11 -10
  2. backend_utils.py +3 -12
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
3
- get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES)
 
4
 
5
  st.set_page_config(
6
  page_title="Retrieval Augmentation with Haystack",
@@ -51,39 +52,39 @@ st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieva
51
  # QUERIES,
52
  # key='q_drop_down', on_change=set_question)
53
 
54
- st.markdown("<h5> Answer with GPT's Internal Knowledge </h5>", unsafe_allow_html=True)
55
  placeholder_plain_gpt = st.empty()
56
  st.text(" ")
57
  st.text(" ")
58
  if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
59
- st.markdown("<h5> Answer with Retrieval Augmented GPT (Static news dataset) </h5>", unsafe_allow_html=True)
60
  else:
61
- st.markdown("<h5> Answer with Retrieval Augmented GPT (Web Search) </h5>", unsafe_allow_html=True)
62
  placeholder_retrieval_augmented = st.empty()
63
 
64
  if st.session_state.get('query') and run_pressed:
65
  ip = st.session_state['query']
66
  with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
67
  p1 = get_plain_pipeline()
68
- with st.spinner('Fetching answers from GPT\'s internal knowledge... '
69
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
70
  answers = p1.run(ip)
71
  placeholder_plain_gpt.markdown(answers['results'][0])
72
 
73
  if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
74
  with st.spinner(
75
- 'Loading Retrieval Augmented pipeline... \
76
- n This may take a few mins and might also fail if OpenAI API server is down.'):
77
  p2 = get_retrieval_augmented_pipeline()
78
- with st.spinner('Fetching relevant documents from documented stores and calculating answers... '
79
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
80
  answers_2 = p2.run(ip)
81
  else:
82
  with st.spinner(
83
- 'Loading Retrieval Augmented pipeline... \
84
  n This may take a few mins and might also fail if OpenAI API server is down.'):
85
  p3 = get_web_retrieval_augmented_pipeline()
86
- with st.spinner('Fetching relevant documents from the Web and calculating answers... '
87
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
88
  answers_2 = p3.run(ip)
89
  placeholder_retrieval_augmented.markdown(answers_2['results'][0])
 
1
  import streamlit as st
2
  from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
3
+ get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES,
4
+ PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS)
5
 
6
  st.set_page_config(
7
  page_title="Retrieval Augmentation with Haystack",
 
52
  # QUERIES,
53
  # key='q_drop_down', on_change=set_question)
54
 
55
+ st.markdown(f"<h5> {PLAIN_GPT_ANS} </h5>", unsafe_allow_html=True)
56
  placeholder_plain_gpt = st.empty()
57
  st.text(" ")
58
  st.text(" ")
59
  if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
60
+ st.markdown(f"<h5> {GPT_LOCAL_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
61
  else:
62
+ st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS} </h5>", unsafe_allow_html=True)
63
  placeholder_retrieval_augmented = st.empty()
64
 
65
  if st.session_state.get('query') and run_pressed:
66
  ip = st.session_state['query']
67
  with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
68
  p1 = get_plain_pipeline()
69
+ with st.spinner('Fetching answers from plain GPT... '
70
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
71
  answers = p1.run(ip)
72
  placeholder_plain_gpt.markdown(answers['results'][0])
73
 
74
  if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
75
  with st.spinner(
76
+ 'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
77
+ '\n This may take a few mins and might also fail if OpenAI API server is down.'):
78
  p2 = get_retrieval_augmented_pipeline()
79
+ with st.spinner('Getting relevant documents from documented stores and calculating answers... '
80
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
81
  answers_2 = p2.run(ip)
82
  else:
83
  with st.spinner(
84
+ 'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... \
85
  n This may take a few mins and might also fail if OpenAI API server is down.'):
86
  p3 = get_web_retrieval_augmented_pipeline()
87
+ with st.spinner('Getting relevant documents from the Web and calculating answers... '
88
  '\n This may take a few mins and might also fail if OpenAI API server is down.'):
89
  answers_2 = p3.run(ip)
90
  placeholder_retrieval_augmented.markdown(answers_2['results'][0])
backend_utils.py CHANGED
@@ -12,6 +12,9 @@ QUERIES = [
12
  "Who is responsible for SVC collapse?",
13
  "When did SVB collapse?"
14
  ]
 
 
 
15
 
16
 
17
  @st.cache_resource(show_spinner=False)
@@ -76,18 +79,6 @@ def get_web_retrieval_augmented_pipeline():
76
  return pipeline
77
 
78
 
79
- # @st.cache_resource(show_spinner=False)
80
- # def app_init():
81
- # print("Loading Pipelines...")
82
- # p1 = get_plain_pipeline()
83
- # print("Loaded Plain Pipeline")
84
- # p2 = get_retrieval_augmented_pipeline()
85
- # print("Loaded Retrieval Augmented Pipeline")
86
- # p3 = get_web_retrieval_augmented_pipeline()
87
- # print("Loaded Web Retrieval Augmented Pipeline")
88
- # return p1, p2, p3
89
-
90
-
91
  if 'query' not in st.session_state:
92
  st.session_state['query'] = ""
93
 
 
12
  "Who is responsible for SVC collapse?",
13
  "When did SVB collapse?"
14
  ]
15
+ PLAIN_GPT_ANS = "Answer with plain GPT"
16
+ GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Static news dataset)"
17
+ GPT_WEB_RET_AUG_ANS = "Answer with Retrieval Augmented GPT (Web Search)"
18
 
19
 
20
  @st.cache_resource(show_spinner=False)
 
79
  return pipeline
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if 'query' not in st.session_state:
83
  st.session_state['query'] = ""
84