daniel Foley committed on
Commit
08c6b0b
·
1 Parent(s): 9b667a3

test hf concurrence

Browse files
Files changed (1) hide show
  1. streamlit-rag-app.py +127 -38
streamlit-rag-app.py CHANGED
@@ -1,96 +1,185 @@
1
  import streamlit as st
 
2
  import os
 
3
  import json
 
4
  from dotenv import load_dotenv
5
 
6
- from langchain.chains import RetrievalQA
 
 
 
7
  from langchain_community.vectorstores import FAISS
 
8
  from langchain.text_splitter import CharacterTextSplitter
9
- from langchain.chat_models import ChatOpenAI
 
 
10
  from langchain.schema import Document
11
- from langchain.embeddings import HuggingFaceEmbeddings
 
 
 
 
 
 
 
 
 
12
 
13
  # Load environment variables
 
14
  load_dotenv()
15
 
 
 
16
  # Get the OpenAI API key from the environment
 
17
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
18
  if not OPENAI_API_KEY:
 
19
  st.error("OPENAI_API_KEY is not set. Please add it to your .env file.")
20
 
 
 
21
  # Initialize session state variables
 
22
  if 'vector_store' not in st.session_state:
 
23
  st.session_state.vector_store = None
24
- if 'qa_chain' not in st.session_state:
25
- st.session_state.qa_chain = None
26
-
27
- def load_json_file(file_path):
28
- """Load JSON data from a file."""
29
- with open(file_path, "r", encoding="utf-8") as file:
30
- data = json.load(file)
31
- return data
32
-
33
- def setup_vector_store_from_json(json_data):
34
- """Create a vector store from JSON data."""
35
- documents = [Document(page_content=item["content"], metadata={"url": item["url"]}) for item in json_data]
36
 
37
- # Use HuggingFace embeddings
38
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
39
 
40
- vector_store = FAISS.from_documents(documents, embeddings)
41
- return vector_store
42
 
43
- def setup_qa_chain(vector_store):
44
- """Set up the QA chain with a retriever."""
45
- retriever = vector_store.as_retriever(search_kwargs={"k": 3})
46
- llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=OPENAI_API_KEY)
47
- qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, return_source_documents=True)
48
- return qa_chain
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def main():
 
51
  # Set page title and header
52
- st.set_page_config(page_title="LibRAG, page_icon="📖")
 
 
 
 
53
  st.title("Boston Public Library Database 📚")
54
 
 
 
 
 
 
 
55
  # Sidebar for initialization
56
- st.sidebar.header("Initialize Knowledge Base")
57
- if st.sidebar.button("Load Data"):
58
- try:
59
- # Load and preprocess the JSON file
60
- json_data = load_json_file(".json")
61
- st.session_state.vector_store = setup_vector_store_from_json(json_data)
62
- st.session_state.qa_chain = setup_qa_chain(st.session_state.vector_store)
63
- st.sidebar.success("Knowledge base loaded successfully!")
64
- except Exception as e:
65
- st.sidebar.error(f"Error loading data: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Query input and processing
 
68
  st.header("Ask a Question")
 
69
  query = st.text_input("Enter your question about BPL's database")
70
 
 
 
71
  if query:
 
72
  # Check if vector store and QA chain are initialized
73
- if st.session_state.qa_chain is None:
 
 
74
  st.warning("Please load the knowledge base first using the sidebar.")
 
75
  else:
 
76
  # Run the query
 
77
  try:
78
- response = st.session_state.qa_chain({"query": query})
 
 
79
 
 
80
  # Display answer
 
81
  st.subheader("Answer")
 
82
  st.write(response["result"])
83
 
 
 
84
  # Display sources
 
85
  st.subheader("Sources")
 
86
  sources = response["source_documents"]
 
87
  for i, doc in enumerate(sources, 1):
 
88
  with st.expander(f"Source {i}"):
 
89
  st.write(f"**Content:** {doc.page_content}")
 
90
  st.write(f"**URL:** {doc.metadata.get('url', 'No URL available')}")
91
 
 
 
92
  except Exception as e:
 
93
  st.error(f"An error occurred: {e}")
94
 
 
 
95
  if __name__ == "__main__":
 
96
  main()
 
1
  import streamlit as st
2
+
3
  import os
4
+
5
  import json
6
+
7
  from dotenv import load_dotenv
8
 
9
+
10
+
11
+ # from langchain.chains import RetrievalQA
12
+
13
  from langchain_community.vectorstores import FAISS
14
+
15
  from langchain.text_splitter import CharacterTextSplitter
16
+
17
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings, OpenAI
18
+
19
  from langchain.schema import Document
20
+
21
+ from langchain_huggingface import HuggingFaceEmbeddings
22
+
23
+ from langchain.chains.combine_documents import create_stuff_documents_chain
24
+
25
+ from langchain.chains.retrieval import create_retrieval_chain
26
+
27
+ from langchain_core.prompts import PromptTemplate
28
+
29
+
30
 
# Read key/value pairs from a local .env file into the process environment
# before any configuration is looked up.
load_dotenv()

# The OpenAI credential is taken from the environment. When it is absent the
# app shows an error message in the page instead of failing silently.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("OPENAI_API_KEY is not set. Please add it to your .env file.")

# Guarantee that the session state always has a slot for the vector store,
# so later code can test it against None.
if "vector_store" not in st.session_state:
    st.session_state["vector_store"] = None
52
+
53
+ # if 'qa_chain' not in st.session_state:
54
+
55
+ # st.session_state.qa_chain = None
56
+
 
 
 
 
 
 
 
57
 
58
+
 
59
 
 
 
60
 
61
+ # def setup_qa_chain(vector_store):
62
+
63
+ # """Set up the QA chain with a retriever."""
64
+
65
+ # retriever = vector_store.as_retriever(search_kwargs={"k": 3})
66
+
67
+ # llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=OPENAI_API_KEY)
68
+
69
+ # qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, return_source_documents=True)
70
+
71
+ # return qa_chain
72
+
73
+
74
+
75
# Prompt consumed by the stuff-documents chain. ``create_retrieval_chain``
# passes the user's question under "input" and the retrieved documents under
# "context", so the template must use exactly those two variable names.
prompt_template = PromptTemplate.from_template(
    "Answer the following query based on a number of context documents "
    "Query:{input},Context:{context},Answer:"
)


def _build_retrieval_chain():
    """Load the on-disk FAISS index and wire it into a retrieval chain.

    Returns:
        A ``(vector_store, retrieval_chain)`` tuple. The chain retrieves the
        three nearest documents for a question and passes them, together with
        the question, to the chat model through ``prompt_template``.

    Raises:
        Exception: whatever the embedding model, ``FAISS.load_local`` or the
            chat-model constructor raise (missing index directory, missing
            credentials, ...). The caller reports these to the user.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # NOTE(security): ``allow_dangerous_deserialization=True`` unpickles the
    # stored index. Only point this at a "vector-store" directory that this
    # project produced itself, never at untrusted data.
    vector_store = FAISS.load_local(
        "vector-store", embeddings, allow_dangerous_deserialization=True
    )
    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=OPENAI_API_KEY)
    combine_docs_chain = create_stuff_documents_chain(llm, prompt_template)
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    return vector_store, create_retrieval_chain(retriever, combine_docs_chain)


def main():
    """Render the LibRAG page: load the knowledge base once, then answer queries.

    The vector store and retrieval chain are built on the first run of a
    session and kept in ``st.session_state`` so that Streamlit reruns (one per
    widget interaction) do not reload the embedding model and index.
    """
    # Set page title and header
    st.set_page_config(page_title="LibRAG", page_icon="📚")
    st.title("Boston Public Library Database 📚")

    # The module-level check has already reported a missing key; stop here
    # rather than build a chat client that cannot authenticate.
    if not OPENAI_API_KEY:
        st.stop()

    # Build the knowledge base once per session.
    if "retrieval_chain" not in st.session_state:
        try:
            vector_store, retrieval_chain = _build_retrieval_chain()
        except Exception as e:
            st.error(f"Error loading data: {e}")
            st.stop()
        st.session_state.vector_store = vector_store
        st.session_state.retrieval_chain = retrieval_chain

    # Query input and processing
    st.header("Ask a Question")
    query = st.text_input("Enter your question about BPL's database")

    if query:
        # Run the query
        try:
            response = st.session_state.retrieval_chain.invoke({"input": query})

            # Display answer (the retrieval chain returns it under "answer")
            st.subheader("Answer")
            st.write(response["answer"])

            # Display sources (the retrieved documents are under "context")
            st.subheader("Sources")
            for i, doc in enumerate(response["context"], 1):
                with st.expander(f"Source {i}"):
                    st.write(f"**Content:** {doc.page_content}")
                    st.write(f"**URL:** {doc.metadata.get('url', 'No URL available')}")

        except Exception as e:
            st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()