from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain import HuggingFaceHub
from langchain.chains import RetrievalQA
import streamlit as st
# Page chrome for the Streamlit app (set_page_config must be the first st.* call in the script).
st.set_page_config(page_title = "Hospital Regulatory Chat", page_icon=":hospital:")
# Directory holding the prebuilt FAISS index loaded by get_vectorstore()
# (presumably index.faiss / index.pkl live here — verify against the build step).
DB_FAISS_PATH = '.'
@st.cache_resource
def get_vectorstore():
    """Load the prebuilt FAISS index from DB_FAISS_PATH wrapped with MiniLM embeddings.

    Decorated with ``st.cache_resource`` so the embedding model and index are
    loaded once per server process instead of on every Streamlit rerun
    (the script re-executes top to bottom on each user interaction).

    Returns:
        A FAISS vector store ready to be used as a retriever.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},  # CPU-only host; no GPU assumed
    )
    # NOTE(review): newer langchain releases require
    # allow_dangerous_deserialization=True on load_local — confirm the pinned version.
    vector_store = FAISS.load_local(DB_FAISS_PATH, embeddings)
    return vector_store
# Wire up the retrieval pipeline: FAISS store -> top-10 retriever -> LLM QA chain.
vector_store = get_vectorstore()

llm = HuggingFaceHub(
    repo_id="meta-llama/Llama-2-7b-chat-hf",
    model_kwargs={"temperature": 0.5},
)

# 'stuff' chain type: all retrieved excerpts are stuffed into one prompt.
retriever = vector_store.as_retriever(search_kwargs={'k': 10})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=retriever,
    return_source_documents=True,  # the UI renders each source excerpt below the answer
)
# Map raw document paths (exactly as stored in the FAISS document metadata)
# to the markdown links shown to the user.  Raw strings preserve the
# Windows-style backslashes: without the r-prefix, "\C", "\D" and "\W" are
# invalid escape sequences (a SyntaxWarning on modern Python) that only
# happen to pass through literally.
source_dictionary = {
    r"data\CMS_SOMA.pdf": "[CMS State Operations Manual Appendix A](https://www.cms.gov/regulations-and-guidance/guidance/manuals/downloads/som107ap_a_hospitals.pdf)",
    r"data\DOH-RCW.pdf": "[Revised Code of Washington (RCW) Chapter 70.41](https://app.leg.wa.gov/rcw/default.aspx?cite=70.41)",
    r"data\WAC 246-320.pdf": "[Washington Administrative Code (WAC) 246-320](https://app.leg.wa.gov/WAC/default.aspx?cite=246-320)",
}
# --- Page layout: title, sidebar with source links, and the question input. ---
with st.container():
    st.title("Hospital Regulation Chat")

with st.sidebar:
    st.subheader("Find regulations for hospitals in the state of Washington.")
    # Markdown body lines are kept at column 0: indenting them 4+ spaces
    # inside the triple-quoted string would render as a code block.
    st.markdown("""
We look into three sources to find top ten most relevant excerpts:
- [CMS State Operations Manual Appendix A](https://www.cms.gov/regulations-and-guidance/guidance/manuals/downloads/som107ap_a_hospitals.pdf)
- [Revised Code of Washington (RCW) Chapter 70.41](https://app.leg.wa.gov/rcw/default.aspx?cite=70.41)
- [Washington Administrative Code (WAC) 246-320](https://app.leg.wa.gov/WAC/default.aspx?cite=246-320)
""")
    # Typo fixed: "This is tool is meant" -> "This tool is meant".
    st.write("This tool is meant to assist healthcare workers to the extent it can. Please note that the page numbers may be occasionally slightly off, use the included excerpt to find the reference if this happens.")

# Main area: question prompt (the matching button and results follow below).
st.markdown("**Ask your question and :red[click 'Find Matches'.]**")
prompt = st.text_input("e.g. Should all employees undergo background checks? ")
# --- Run the QA chain on demand and render each retrieved excerpt. ---
if st.button("Find Matches"):
    if not prompt:
        # Avoid an empty LLM query on an accidental click.
        st.warning("Please enter a question first.")
    else:
        answer = qa_chain({"query": prompt})
        # Iterate the documents directly instead of indexing with range(len(...)).
        for doc in answer['source_documents']:
            with st.container():
                # Fall back to the raw path for any source missing from the
                # mapping instead of crashing with a KeyError.
                source = doc.metadata.get('source', '')
                st.subheader(source_dictionary.get(source, source))
                st.markdown("**Page: " + str(doc.metadata.get('page', '?')) + "**")
                st.write("...")
                st.write(doc.page_content)
                st.write("...")
                st.write('---------------------------------\n\n')
@@ -1,99 +1,63 @@
|
|
1 |
-
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
|
2 |
-
from langchain import PromptTemplate
|
3 |
-
from langchain import HuggingFaceHub
|
4 |
from langchain.embeddings import HuggingFaceEmbeddings
|
5 |
from langchain.vectorstores import FAISS
|
6 |
-
from langchain
|
7 |
from langchain.chains import RetrievalQA
|
8 |
-
import chainlit as cl
|
9 |
import streamlit as st
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
custom_prompt_template = """Use the following pieces of information to answer the user's question.
|
14 |
-
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
15 |
-
|
16 |
-
Context: {context}
|
17 |
-
Question: {question}
|
18 |
|
19 |
-
Only return the helpful answer below and nothing else.
|
20 |
-
Helpful answer:
|
21 |
-
"""
|
22 |
-
|
23 |
-
def set_custom_prompt():
|
24 |
-
"""
|
25 |
-
Prompt template for QA retrieval for each vectorstore
|
26 |
-
"""
|
27 |
-
prompt = PromptTemplate(template=custom_prompt_template,
|
28 |
-
input_variables=['context', 'question'])
|
29 |
-
return prompt
|
30 |
-
|
31 |
-
#Retrieval QA Chain
|
32 |
-
def retrieval_qa_chain(llm, prompt, db):
|
33 |
-
qa_chain = RetrievalQA.from_chain_type(llm=llm,
|
34 |
-
chain_type='stuff',
|
35 |
-
retriever=db.as_retriever(search_kwargs={'k': 2}),
|
36 |
-
return_source_documents=True,
|
37 |
-
chain_type_kwargs={'prompt': prompt}
|
38 |
-
)
|
39 |
-
return qa_chain
|
40 |
|
41 |
-
|
42 |
-
def load_llm():
|
43 |
-
# Load the locally downloaded model here
|
44 |
-
llm = HuggingFaceHub(repo_id = "meta-llama/Llama-2-7b-chat-hf", model_kwargs={"temperature":0.5}) #, "max_length":512})
|
45 |
-
return llm
|
46 |
|
47 |
-
|
48 |
-
def qa_bot():
|
49 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
|
50 |
model_kwargs={'device': 'cpu'})
|
51 |
-
|
52 |
-
|
53 |
-
qa_prompt = set_custom_prompt()
|
54 |
-
qa = retrieval_qa_chain(llm, qa_prompt, db)
|
55 |
-
|
56 |
-
return qa
|
57 |
-
|
58 |
-
#output function
|
59 |
-
def final_result(query):
|
60 |
-
qa_result = qa_bot()
|
61 |
-
response = qa_result({'query': query})
|
62 |
-
return response
|
63 |
-
|
64 |
|
65 |
-
|
66 |
|
67 |
-
|
68 |
-
if user_question:
|
69 |
-
st.write("Your Question: ",user_question)
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
await msg.send()
|
78 |
-
msg.content = "Hi, Welcome to Medical Bot. What is your query?"
|
79 |
-
await msg.update()
|
80 |
-
|
81 |
-
cl.user_session.set("chain", chain)
|
82 |
-
|
83 |
-
@cl.on_message
|
84 |
-
async def main(message):
|
85 |
-
chain = cl.user_session.get("chain")
|
86 |
-
cb = cl.AsyncLangchainCallbackHandler(
|
87 |
-
stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
|
88 |
-
)
|
89 |
-
cb.answer_reached = True
|
90 |
-
res = await chain.acall(message, callbacks=[cb])
|
91 |
-
answer = res["result"]
|
92 |
-
sources = res["source_documents"]
|
93 |
|
94 |
-
if sources:
|
95 |
-
answer += f"\nSources:" + str(sources)
|
96 |
-
else:
|
97 |
-
answer += "\nNo sources found"
|
98 |
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain.embeddings import HuggingFaceEmbeddings
|
2 |
from langchain.vectorstores import FAISS
|
3 |
+
from langchain import HuggingFaceHub
|
4 |
from langchain.chains import RetrievalQA
|
|
|
5 |
import streamlit as st
|
6 |
|
7 |
+
st.set_page_config(page_title = "Hospital Regulatory Chat", page_icon=":hospital:")
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
DB_FAISS_PATH = '.'
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def get_vectorstore():
|
|
|
13 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
|
14 |
model_kwargs={'device': 'cpu'})
|
15 |
+
vector_store = FAISS.load_local(DB_FAISS_PATH, embeddings)
|
16 |
+
return vector_store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
vector_store = get_vectorstore()
|
19 |
|
20 |
+
llm = HuggingFaceHub(repo_id = "meta-llama/Llama-2-7b-chat-hf",model_kwargs={"temperature":0.5}) #, "max_length":512})
|
|
|
|
|
21 |
|
22 |
+
qa_chain = RetrievalQA.from_chain_type(llm=llm,
|
23 |
+
chain_type='stuff',
|
24 |
+
retriever=vector_store.as_retriever(search_kwargs={'k': 10}),
|
25 |
+
#retriever=vector_store.as_retriever(search_kwargs={"score_threshold": .01}),
|
26 |
+
return_source_documents = True
|
27 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
source_dictionary = {"data\CMS_SOMA.pdf":"[CMS State Operations Manual Appendix A](https://www.cms.gov/regulations-and-guidance/guidance/manuals/downloads/som107ap_a_hospitals.pdf)",
|
31 |
+
"data\DOH-RCW.pdf":"[Revised Code of Washington (RCW) Chapter 70.41](https://app.leg.wa.gov/rcw/default.aspx?cite=70.41)",
|
32 |
+
"data\WAC 246-320.pdf":"[Washington Administrative Code (WAC) 246-320](https://app.leg.wa.gov/WAC/default.aspx?cite=246-320)"}
|
33 |
+
|
34 |
+
with st.container():
|
35 |
+
st.title("Hospital Regulation Chat")
|
36 |
+
|
37 |
+
with st.sidebar():
|
38 |
+
st.subheader("Find regulations for hospitals in the state of Washington.")
|
39 |
+
st.markdown("""
|
40 |
+
We look into three sources to find top ten most relevant excerpts:
|
41 |
+
- [CMS State Operations Manual Appendix A](https://www.cms.gov/regulations-and-guidance/guidance/manuals/downloads/som107ap_a_hospitals.pdf)
|
42 |
+
- [Revised Code of Washington (RCW) Chapter 70.41](https://app.leg.wa.gov/rcw/default.aspx?cite=70.41)
|
43 |
+
- [Washington Administrative Code (WAC) 246-320](https://app.leg.wa.gov/WAC/default.aspx?cite=246-320)
|
44 |
+
""") #, unsafe_allow_html=True)
|
45 |
+
st.write("This is tool is meant to assist healthcare workers to the extent it can. Please note that the page numbers may be occasionally slightly off, use the matching excerpts to find the reference if this happens.")
|
46 |
+
|
47 |
+
st.markdown("**Ask your question and :red[click 'Find Excerpts'.]**")
|
48 |
+
prompt = st.text_input("e.g. Should all employees undergo background checks? ")
|
49 |
+
|
50 |
+
if (st.button("Find Excerpts")):
|
51 |
+
answer = qa_chain({"query":prompt})
|
52 |
+
|
53 |
+
n = len(answer['source_documents'])
|
54 |
+
|
55 |
+
for i in range(n):
|
56 |
+
with st.container():
|
57 |
+
st.subheader(source_dictionary[answer['source_documents'][i].metadata['source']])
|
58 |
+
page_no = "**Page: " + str(answer['source_documents'][i].metadata['page']) + "**"
|
59 |
+
st.markdown(page_no)
|
60 |
+
st.write("...")
|
61 |
+
st.write(answer['source_documents'][i].page_content)
|
62 |
+
st.write("...")
|
63 |
+
st.write('---------------------------------\n\n')
|