import os

import streamlit as st
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_together import Together

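# LLM served by the Together API; the `pilotikval` environment variable must
# contain a valid Together API key.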
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    top_k=12,
    max_tokens=22048,
    together_api_key=os.environ['pilotikval']
)

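# Module-level history store, keyed by session id. Note it is shared by every
# session of the running Streamlit process.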
store = {}

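# BGE embeddings; normalized vectors are recommended for cosine-similarity retrieval.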
model_name = "BAAI/bge-base-en"
encode_kwargs = {'normalize_embeddings': True}

embedding_function = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    encode_kwargs=encode_kwargs
)

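# Return (creating on first use) the Streamlit-backed chat history for a session.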
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = StreamlitChatMessageHistory(key=session_id)
    return store[session_id]

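# Streamlit entry point: the sidebar picks a document collection, the main
# pane hosts the conversational RAG chat.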
def app():
    with st.sidebar:
        st.title("dochatter")
        option = st.selectbox(
            'Which retriever would you like to use?',
            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
        )

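    # Map the sidebar choice to an on-disk Chroma store and collection name.
    # Keys must match the selectbox options, otherwise .get() falls back to the default.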
    persist_directory = {
        'General Medicine': "./oxfordmedbookdir/",
        'RespiratoryFishman': "./respfishmandbcud/",
        'RespiratoryMurray': "./respmurray/",
        'MedMRCP2': "./medmrcp2store/",
        'OldMedicine': "./mrcpchromadb/"
    }.get(option, "./mrcpchromadb/")

    collection_name = {
        'General Medicine': "oxfordmed",
        'RespiratoryFishman': "fishmannotescud",
        'RespiratoryMurray': "respmurraynotes",
        'MedMRCP2': "medmrcp2notes",
        'OldMedicine': "mrcppassmednotes"
    }.get(option, "mrcppassmednotes")

    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
        collection_name=collection_name
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 5})

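    # Step 1: rewrite the latest question into a standalone query using the
    # chat history, so retrieval works for follow-up questions.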
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_q_prompt
    )

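    # Step 2: answer from the retrieved documents, stuffed into the prompt as {context}.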
    system_prompt = (
        "You are helping a doctor. Be as detailed and thorough as possible. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know."
        "\n\n"
        "{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

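    # Wrap the RAG chain so per-session chat history is loaded and saved automatically.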
    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )

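    # Seed the displayed conversation with a greeting on first load.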
if "messages" not in st.session_state.keys(): |
|
st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}] |
|
|
|
    st.header("Hello Doc!")
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    prompts2 = st.chat_input("Say something")

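    # Record and echo the user's new message.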
    if prompts2:
        st.session_state.messages.append({"role": "user", "content": prompts2})
        with st.chat_message("user"):
            st.write(prompts2)

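    # Generate a reply only when the newest message came from the user.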
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                final_response = conversational_rag_chain.invoke(
                    {"input": prompts2},
                    config={"configurable": {"session_id": "current_session"}}
                )
                st.write(final_response['answer'])
        st.session_state.messages.append({"role": "assistant", "content": final_response['answer']})

if __name__ == '__main__':
    app()